Compare commits

...

107 Commits

Author SHA1 Message Date
CaIon
a29f4d88c5 Update model-ratio.go 2025-04-04 23:41:41 +08:00
CaIon
a6bb30af41 fix: Improve setup check logic and logging for system initialization 2025-04-04 21:27:24 +08:00
CaIon
424424c160 Update model-ratio.go 2025-04-04 00:31:24 +08:00
CaIon
e5baa6ee1c feat: Enhance ModelSettingsVisualEditor with pricing modes and improved model management features 2025-04-03 20:42:08 +08:00
CaIon
9207d729ca feat: Add new localization strings for system initialization 2025-04-03 19:27:25 +08:00
CaIon
27933da884 fix: Update option key from SelfUseModeEnabled to DemoSiteEnabled in PostSetup function 2025-04-03 19:21:53 +08:00
CaIon
454dac17ea feat: Add timestamp and version to setup initialization in PostSetup function 2025-04-03 19:16:17 +08:00
CaIon
1921ac3692 fix: Correct option key for SelfUseModeEnabled in setup controller 2025-04-03 19:15:04 +08:00
CaIon
42a2418d9a Merge remote-tracking branch 'origin/main' 2025-04-03 19:09:26 +08:00
CaIon
5cb317bdbd Update README.md 2025-04-03 19:09:13 +08:00
Calcium-Ion
37dd1ef099 Merge pull request #925 from Calcium-Ion/setup
 feat: Implement system setup functionality
2025-04-03 19:01:45 +08:00
CaIon
5fa6462412 feat: Refine personal mode description in setup page for clarity 2025-04-03 19:01:16 +08:00
CaIon
a882e680ae feat: Implement system setup functionality 2025-04-03 18:57:15 +08:00
CaIon
552e2850c5 Merge remote-tracking branch 'origin/main' 2025-04-03 17:33:03 +08:00
CaIon
c418d9ed9a feat: Enhance user settings and notification options 2025-04-03 17:32:48 +08:00
Calcium-Ion
1dc2284d57 Merge pull request #909 from jasinliu/feature/fix-dify-thinking
feat: fix dify thinking
2025-04-03 16:23:12 +08:00
Calcium-Ion
f4cc90c8d6 Merge pull request #893 from wizcas/replace-linux-do-icon
替换登录界面的 Linux.do OAuth 图标
2025-03-31 22:38:41 +08:00
Calcium-Ion
140d3a974b Merge pull request #895 from Feiyuyu0503/main
docs: fix a typo
2025-03-31 22:38:25 +08:00
Calcium-Ion
2ecb742e47 Merge pull request #912 from OrdinarySF/main
fix: fixed bug where target.id was null when clicking 'x' icon
2025-03-31 22:38:08 +08:00
Calcium-Ion
9066cfa8a0 Merge pull request #914 from JoeyLearnsToCode/main
feat: Add Parameters Override
2025-03-31 22:37:26 +08:00
Calcium-Ion
4f437f30e0 Merge pull request #916 from xifan2333/fix/systemSettingsUI
 feat: Update option handling in SystemSetting
2025-03-31 22:36:14 +08:00
xifan
3c2a86f94d feat: Update option handling in SystemSetting
-  Add backend validation for OIDC & Telegram OAuth config
- ♻️ Refactor frontend option updates with batch processing
2025-03-31 00:46:13 +08:00
JoeyLearnsToCode
1b07282153 feat: Add Parameters Override 2025-03-29 14:39:39 +08:00
Ordinary
af7f886c39 refactor: use handleFieldChange function on change event 2025-03-28 12:44:40 +00:00
Ordinary
9cfa138796 fix: fixed bug where target.id was null when clicking 'x' icon 2025-03-28 12:43:26 +00:00
jasinliu
dc132655a6 fix dify thinking 2025-03-28 00:21:27 +08:00
1808837298@qq.com
a378665b8c feat: Add new cache ratios for o3-mini and gpt-4.5-preview models 2025-03-27 18:47:50 +08:00
1808837298@qq.com
3516aad349 update model ratio 2025-03-27 17:02:09 +08:00
1808837298@qq.com
58525c574b feat: Enhance GetCompletionRatio function 2025-03-27 16:38:29 +08:00
1808837298@qq.com
1df39e5a7f update model ratio 2025-03-27 16:24:30 +08:00
feiyuyu
be6ffd3c60 docs: fix a typo 2025-03-22 21:28:25 +08:00
Wizcas Chen
a9522075c6 replace the linuxdo icon in the login form 2025-03-22 17:16:07 +08:00
Calcium-Ion
983d31bfd3 Merge pull request #886 from seefs001/main
fix: claude function calling type
2025-03-20 23:22:20 +08:00
Seefs
20c043f584 fix: claude function calling type 2025-03-19 22:48:49 +08:00
1808837298@qq.com
73263e02d6 fix: Adjust MaxTokens logic for non-Claude models in test request 2025-03-17 23:44:32 +08:00
1808837298@qq.com
7143b0f160 feat: Add support for cross-region AWS model handling in awsStreamHandler 2025-03-17 23:41:00 +08:00
1808837298@qq.com
dd82618c05 refactor: Improve token quota consumption logic 2025-03-17 17:52:54 +08:00
1808837298@qq.com
19935ee8ac feat: Enhance ConvertClaudeRequest method to set request model and handle vertex-specific request conversion 2025-03-17 17:13:33 +08:00
1808837298@qq.com
6fef5aaf22 feat: Update RerankerInfo structure and modify GenRelayInfoRerank function to accept RerankRequest 2025-03-17 16:44:53 +08:00
Calcium-Ion
b5aa3c129b Merge pull request #872 from neotf/main
feat: support AWS Model CrossRegion
2025-03-17 16:18:11 +08:00
1808837298@qq.com
8c7c39550c refactor: Update ClaudeResponse error handling to use pointer for ClaudeError and improve nil checks in response processing 2025-03-16 23:14:45 +08:00
1808837298@qq.com
962e803d8a Update README 2025-03-16 21:53:00 +08:00
1808837298@qq.com
ff57ced2bb Update README 2025-03-16 21:47:32 +08:00
1808837298@qq.com
2223806c00 Update README 2025-03-16 21:17:08 +08:00
1808837298@qq.com
d1c62a583d feat: support xinference rerank to jina format 2025-03-16 21:06:29 +08:00
1808837298@qq.com
53b3599827 refactor: Enhance Claude response handling 2025-03-16 19:11:58 +08:00
1808837298@qq.com
b3b1c803fc feat: Introduce JSON decoding utility functions and update error handling in Claude and OpenAI response structures 2025-03-16 18:34:39 +08:00
1808837298@qq.com
a4a40c495d Merge remote-tracking branch 'origin/main' 2025-03-16 16:48:15 +08:00
1808837298@qq.com
ee302c063c refactor: Enhance error handling in AWS and Claude response processing by updating function signatures and improving error propagation 2025-03-16 16:47:16 +08:00
Calcium-Ion
5a67bdf1b0 Merge pull request #851 from HynoR/main
Fix: 修正DeepSeek缓存倍率
2025-03-16 16:31:48 +08:00
1808837298@qq.com
2c81a5f0cc refactor: Streamline AWS and Claude response handling by consolidating logic and improving error management 2025-03-16 16:07:51 +08:00
Calcium-Ion
b84b6affe9 Merge pull request #874 from HynoR/feat/gemini2
Chore: Sync Cohere Latest Model
2025-03-15 19:44:37 +08:00
1808837298@qq.com
c183c1231c refactor: Replace direct access to ImageUrl with GetImageMedia method across multiple relay channels 2025-03-15 19:43:37 +08:00
1808837298@qq.com
54e738941d feat: Add warning modal for base URL input and display warning banner for specific channel type in EditChannel component 2025-03-15 19:38:05 +08:00
1808837298@qq.com
dd393cd0d9 feat: support dify upload image file 2025-03-15 19:10:12 +08:00
TAKO
e98849048c Sync Cohere Latest Model 2025-03-15 12:12:46 +08:00
TAKO
8e68bcce29 Merge branch 'main' into main 2025-03-15 12:08:44 +08:00
neotf
892d014c26 feat: support AWS Model CrossRegion 2025-03-15 01:42:24 +08:00
1808837298@qq.com
19bfa158cc refactor: Change ClaudeError field type to non-pointer and enhance response handling with reasoning content 2025-03-14 17:48:26 +08:00
CalciumIon
69e44a03b1 refactor: Simplify OpenAI handler function signature and remove unused TextResponseWithError struct; introduce common_handler for rerank functionality 2025-03-14 17:31:05 +08:00
CalciumIon
9a78db8484 feat: Add HasSentThinkingContent field to ThinkingContentInfo struct 2025-03-14 17:09:40 +08:00
Calcium-Ion
a381163402 Merge pull request #867 from Sh1n3zZ/wrong-think-label-fix
fix: wrong thinking labels appear in non-thinking models (#861)
2025-03-14 16:59:56 +08:00
CalciumIon
1644dbc864 refactor: Update token usage calculation in FormatClaudeResponseInfo #865 2025-03-14 17:00:39 +08:00
Sh1n3zZ
cc1400e939 fix: wrong thinking labels appear in non-thinking models (#861) 2025-03-14 03:13:52 +08:00
1808837298@qq.com
6187656aa9 chore: Update GitHub Actions workflows and refactor adaptor logic for Docker image builds 2025-03-13 21:10:39 +08:00
Calcium-Ion
e5b6aa6e85 Merge pull request #857 from asjfoajs/main
Refactor: Optimize the ImageHandler under the Alibaba large model to …
2025-03-13 19:51:08 +08:00
1808837298@qq.com
7e46d4217d feat: 初步兼容流模式下openai渠道类型转为claude格式访问 #862 2025-03-13 19:32:08 +08:00
霍雨佳
23596d22c9 Refactor: Optimize the ImageHandler under the Alibaba large model to retrieve the key from the header.
Reason: The info parameter already includes the key, so there is no need to retrieve it again from the header.
Solution: Delete the code for obtaining the key and directly use info.ApiKey.
2025-03-13 08:54:45 +08:00
Calcium-Ion
c25d4d8d23 Update README.md 2025-03-12 22:22:21 +08:00
Calcium-Ion
b291fbff6b Update README.md 2025-03-12 22:13:35 +08:00
Calcium-Ion
e68edf81f7 Update README.md 2025-03-12 22:12:09 +08:00
Calcium-Ion
5ff16f9b2d Merge pull request #854 from seefs001/main
feat: Support postgresql:// dsn format
2025-03-12 21:36:30 +08:00
Calcium-Ion
f614cfa563 Merge pull request #855 from Calcium-Ion/claude
feat: claude relay
2025-03-12 21:36:11 +08:00
1808837298@qq.com
2048b451bf fix panic 2025-03-12 21:35:57 +08:00
1808837298@qq.com
bd48f43410 feat: claude relay 2025-03-12 21:31:46 +08:00
Seefs
c47d8a10f0 feat: Support postgresql:// dsn format 2025-03-12 21:08:47 +08:00
1808837298@qq.com
c0b9350785 fix: claude to openai tools use 2025-03-12 19:46:08 +08:00
1808837298@qq.com
229738cda9 fix: claude to openai tools use 2025-03-12 19:29:15 +08:00
1808837298@qq.com
39d95172e8 fix: claude to openai tools use 2025-03-12 18:53:38 +08:00
1808837298@qq.com
5059cbdb46 Merge remote-tracking branch 'origin/main' 2025-03-12 17:53:52 +08:00
1808837298@qq.com
a981e10712 feat(relay): Add Xinference channel support 2025-03-12 17:53:46 +08:00
TAKO
f7852ada97 Fix Deepseek Cache Ratio 2025-03-12 10:51:12 +08:00
Calcium-Ion
495bbcb621 Merge pull request #848 from wzxjohn/feature/oidc
feat: add oidc support
2025-03-11 23:20:55 +08:00
1808837298@qq.com
20e34bec7e fix: Add error logging for OIDC configuration retrieval 2025-03-11 23:20:27 +08:00
1808837298@qq.com
0033f5ba2e refactor: Update OIDC status check to use oidc_enabled flag 2025-03-11 22:36:31 +08:00
1808837298@qq.com
e52ac52e7b refactor: Remove OIDC configuration from option initialization 2025-03-11 22:03:20 +08:00
1808837298@qq.com
66682584a5 refactor: Migrate OIDC configuration to system settings 2025-03-11 22:00:31 +08:00
1808837298@qq.com
1a2bf8df1f feat(ui): Improve model testing button layout and styling 2025-03-11 21:22:10 +08:00
1808837298@qq.com
1819c4d5f5 feat(error): Enhance error handling with optional detailed error messages 2025-03-11 17:25:06 +08:00
1808837298@qq.com
6f24dddcb2 feat(relay): Add pass-through request option for global settings 2025-03-11 17:02:35 +08:00
1808837298@qq.com
8de29fbb83 Merge remote-tracking branch 'origin/main' 2025-03-11 16:41:18 +08:00
Calcium-Ion
f2163acf2b Merge pull request #849 from OrdinarySF/main
feat(setting): add 'Document Link' option i18n support
2025-03-11 16:27:37 +08:00
Ordinary
5259acfacd feat(setting): add 'Document Link' option i18n support 2025-03-11 08:22:59 +00:00
wzxjohn
c433af284c feat: add oidc support 2025-03-11 15:52:03 +08:00
1808837298@qq.com
3122b8a36a fix: Improve mobile layout and scrolling behavior 2025-03-11 15:05:23 +08:00
1808837298@qq.com
bbe7223a85 Merge remote-tracking branch 'origin/main' 2025-03-11 14:55:56 +08:00
1808837298@qq.com
2af05c166c feat: Improve route handling and dynamic chat navigation in SiderBar 2025-03-11 14:55:48 +08:00
Calcium-Ion
ecb5b5630c Merge pull request #845 from Sh1n3zZ/gemini-embedding
feat: gemini Embeddings support
2025-03-10 23:46:53 +08:00
Sh1n3zZ
e1b9f164f9 feat: gemini Embeddings support 2025-03-10 23:32:06 +08:00
1808837298@qq.com
69db1f1465 Merge remote-tracking branch 'origin/main' 2025-03-10 21:05:43 +08:00
1808837298@qq.com
94549f9687 refactor: Improve responsive design across multiple setting pages 2025-03-10 21:05:22 +08:00
Calcium-Ion
c7e1bab18a Merge pull request #842 from asjfoajs/dev
Fix: Under Ali's large model, the task ID result for image retrieval …
2025-03-10 20:18:53 +08:00
1808837298@qq.com
627f95b034 refactor: Remove unnecessary transition styles and simplify sidebar state management 2025-03-10 20:14:23 +08:00
1808837298@qq.com
8b99eec440 refactor: Improve sidebar state management and layout responsiveness 2025-03-10 19:48:17 +08:00
1808837298@qq.com
49bfd2b719 feat: Enhance mobile UI responsiveness and layout for ChannelsTable and SiderBar 2025-03-10 19:01:56 +08:00
霍雨佳
434e9d7695 Fix: Under Ali's large model, the task ID result for image retrieval is incorrect.
Reason: The URL is incomplete, missing baseurl.
Solution: Add baseurl. url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID).
2025-03-10 16:22:40 +08:00
1808837298@qq.com
b2938ffe2c refactor: Improve mobile responsiveness and scrolling behavior in UI layout 2025-03-10 15:49:32 +08:00
146 changed files with 9870 additions and 6706 deletions

View File

@@ -18,20 +18,20 @@ jobs:
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Save version info
run: |
git describe --tags > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -39,14 +39,14 @@ jobs:
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
images: |
calciumion/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
uses: docker/build-push-action@v5
with:
context: .
push: true

View File

@@ -4,7 +4,6 @@ on:
push:
tags:
- '*'
- '!*-alpha*'
workflow_dispatch:
inputs:
name:
@@ -19,26 +18,26 @@ jobs:
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Save version info
run: |
git describe --tags > VERSION
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -46,14 +45,14 @@ jobs:
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
uses: docker/metadata-action@v5
with:
images: |
calciumion/new-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64

222
README.md
View File

@@ -7,7 +7,6 @@
# New API
🍥新一代大模型网关与AI资产管理系统
<a href="https://trendshift.io/repositories/8227" target="_blank"><img src="https://trendshift.io/api/badge/repositories/8227" alt="Calcium-Ion%2Fnew-api | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
@@ -37,199 +36,154 @@
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发
> [!IMPORTANT]
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> - 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。
> - 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> - 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
## 📚 文档
详细文档请访问我们的官方Wiki[https://docs.newapi.pro/](https://docs.newapi.pro/)
## ✨ 主要特性
1. 🎨 全新的UI界面部分界面还待更新
2. 🌍 多语言支持(待完善)
3. 🎨 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口支持,[对接文档](Midjourney.md)
4. 💰 支持在线充值功能,可在系统设置中设置:
- [x] 易支付
5. 🔍 支持用key查询使用额度
- 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用
6. 📑 分页支持选择每页显示数量
7. 🔄 兼容原版One API的数据库可直接使用原版数据库one-api.db
8. 💵 支持模型按次数收费,可在 系统设置-运营设置 中设置
9. ⚖️ 支持渠道**加权随机**
10. 📈 数据看板(控制台
11. 🔒 可设置令牌能调用的模型
12. 🤖 支持Telegram授权登录
1. 系统设置-配置登录注册-允许通过Telegram登录
2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
3. 选择你的bot然后输入http(s)://你的网站地址/login
4. Telegram Bot 名称是bot username 去掉@后的字符串
13. 🎵 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口支持,[对接文档](Suno.md)
14. 🔄 支持Rerank模型目前兼容Cohere和Jina可接入Dify[对接文档](Rerank.md)
15.**[OpenAI Realtime API](https://platform.openai.com/docs/guides/realtime/integration)** - 支持OpenAI的Realtime API支持Azure渠道
16. 支持使用路由/chat2link 进入聊天界面
17. 🧠 支持通过模型名称后缀设置 reasoning effort
New API提供了丰富的功能详细特性请参考[特性说明](https://docs.newapi.pro/wiki/features-introduction)
1. 🎨 全新的UI界面
2. 🌍 多语言支持
3. 💰 支持在线充值功能(易支付
4. 🔍 支持用key查询使用额度(配合[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)
5. 🔄 兼容原版One API的数据库
6. 💵 支持模型按次数收费
7. ⚖️ 支持渠道加权随机
8. 📈 数据看板(控制台)
9. 🔒 令牌分组、模型限制
10. 🤖 支持更多授权登陆方式LinuxDO,Telegram、OIDC
11. 🔄 支持Rerank模型Cohere和Jina[接口文档](https://docs.newapi.pro/api/jinaai-rerank)
12. 支持OpenAI Realtime API包括Azure渠道[接口文档](https://docs.newapi.pro/api/openai-realtime)
13. ⚡ 支持Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
14. 支持使用路由/chat2link进入聊天界面
15. 🧠 支持通过模型名称后缀设置 reasoning effort
1. OpenAI o系列模型
- 添加后缀 `-high` 设置为 high reasoning effort (例如: `o3-mini-high`)
- 添加后缀 `-medium` 设置为 medium reasoning effort (例如: `o3-mini-medium`)
- 添加后缀 `-low` 设置为 low reasoning effort (例如: `o3-mini-low`)
2. Claude 思考模型
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
18. 🔄 思考转内容,支持在 `渠道-编辑-渠道额外设置` 中设置 `thinking_to_content` 选项,默认`false`,开启后会将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回。
19. 🔄 模型限流,支持在 `系统设置-速率限制设置` 中设置模型限流,支持设置总请求数限制和成功请求数限制
20. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
16. 🔄 思考转内容功能
17. 🔄 针对用户的模型限流功能
18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
1.`系统设置-运营设置` 中设置 `提示缓存倍率` 选项
2. 在渠道中设置 `提示缓存倍率`,范围 0-1例如设置为 0.5 表示缓存命中时按照 50% 计费
3. 支持的渠道:
- [x] OpenAI
- [x] Azure
- [x] DeepSeek
- [ ] Claude
- [x] Claude
## 模型支持
此版本额外支持以下模型:
此版本支持多种模型,详情请参考[接口文档-中继接口](https://docs.newapi.pro/api)
1. 第三方模型 **gpts** gpt-4-gizmo-*
2. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[接文档](Midjourney.md)
3. 自定义渠道,支持填入完整调用地址
4. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
5. Rerank模型,目前支持[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[接文档](Rerank.md)
6. Dify
2. 第三方渠道[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[文档](https://docs.newapi.pro/api/midjourney-proxy-image)
3. 第三方渠道[Suno API](https://github.com/Suno-API/Suno-API)接口,[接口文档](https://docs.newapi.pro/api/suno-music)
4. 自定义渠道,支持填入完整调用地址
5. Rerank模型[Cohere](https://cohere.ai/)和[Jina](https://jina.ai/)[文档](https://docs.newapi.pro/api/jinaai-rerank)
6. Claude Messages 格式,[接口文档](https://docs.newapi.pro/api/anthropic-chat)
7. Dify当前仅支持chatflow
您可以在渠道中添加自定义模型gpt-4-gizmo-*此模型并非OpenAI官方模型而是第三方模型使用官方key无法调用。
## 环境变量配置
## 比原版One API多出的配置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 60 秒。
- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数请求上游返回流模式usage默认为 `true`建议开启不影响客户端传入stream_options参数返回结果。
- `GET_MEDIA_TOKEN`是否统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本如果渠道设置中未指定API版本则使用此版本默认为 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables)
## 已废弃的环境变量
- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
- `STREAMING_TIMEOUT`流式回复超时时间默认60秒
- `DIFY_DEBUG`Dify渠道是否输出工作流和节点信息默认 `true`
- `FORCE_STREAM_OPTION`是否覆盖客户端stream_options参数默认 `true`
- `GET_MEDIA_TOKEN`是否统计图片token默认 `true`
- `GET_MEDIA_TOKEN_NOT_STREAM`非流情况下是否统计图片token默认 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认 `true`
- `COHERE_SAFETY_SETTING`Cohere模型安全设置可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认 `16`
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小单位MB默认 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容
- `AZURE_DEFAULT_API_VERSION`Azure渠道默认API版本默认 `2024-12-01-preview`
- `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制持续时间,默认 `10`分钟
- `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认 `2`
## 部署
详细部署指南请参考[安装指南-部署方式](https://docs.newapi.pro/installation)
> [!TIP]
> 最新版Docker镜像`calciumion/new-api:latest`
> 默认账号root 密码123456
### 多机部署
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
- 如果公用Redis必须设置 `CRYPTO_SECRET`否则会导致多机部署时Redis内容无法获取
### 多机部署注意事项
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致
- 如果公用Redis必须设置 `CRYPTO_SECRET`否则会导致多机部署时Redis内容无法获取
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
- 本地数据库默认SQLiteDocker部署必须挂载`/data`目录)
- 远程数据库MySQL版本 >= 5.7.8PgSQL版本 >= 9.6
### 使用宝塔面板Docker功能部署
安装宝塔面板 (**9.2.0版本**及以上),前往 [宝塔面板](https://www.bt.cn/new/download.html) 官网,选择正式版的脚本下载安装
安装后登录宝塔面板,在菜单栏中点击 Docker ,首次进入会提示安装 Docker 服务,点击立即安装,按提示完成安装
安装完成后在应用商店中找到 **New-API** ,点击安装,配置基本选项 即可完成安装
### 部署方式
#### 使用宝塔面板Docker功能部署
安装宝塔面板(**9.2.0版本**及以上),在应用商店中找到**New-API**安装即可。
[图文教程](BT.md)
### 基于 Docker 进行部署
> [!TIP]
> 默认管理员账号root 密码123456
### 使用 Docker Compose 部署(推荐)
#### 使用Docker Compose部署推荐
```shell
# 下载项目
git clone https://github.com/Calcium-Ion/new-api.git
cd new-api
# 按需编辑 docker-compose.yml
# nano docker-compose.yml
# vim docker-compose.yml
# 按需编辑docker-compose.yml
# 启动
docker-compose up -d
```
#### 更新版本
#### 直接使用Docker镜像
```shell
docker-compose pull
docker-compose up -d
```
### 直接使用 Docker 镜像
```shell
# 使用 SQLite 的部署命令:
# 使用SQLite
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
# 使用MySQL
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
#### 更新版本
```shell
# 拉取最新镜像
docker pull calciumion/new-api:latest
# 停止并删除旧容器
docker stop new-api
docker rm new-api
# 使用相同参数运行新容器
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
## 渠道重试与缓存
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
或者使用 Watchtower 自动更新(不推荐,可能会导致数据库不兼容):
```shell
docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
```
## 渠道重试
渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。
如果开启了重试功能,重试使用下一个优先级,以此类推。
### 缓存设置方法
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true``false`,未设置则默认为 `false`
+ 例子:`MEMORY_CACHE_ENABLED=true`
### 为什么有的时候没有重试
这些错误码不会重试400504524
### 我想让400也重试
`渠道->编辑`中,将`状态码复写`改为
```json
{
"400": "500"
}
```
可以实现400错误转为500错误从而重试
1. `REDIS_CONN_STRING`设置Redis作为缓存
2. `MEMORY_CACHE_ENABLED`启用内存缓存设置了Redis则无需手动设置
## Midjourney接口设置文档
[对接文档](Midjourney.md)
## 接口文档
## Suno接口设置文档
[对接文档](Suno.md)
详细接口文档请参考[接口文档](https://docs.newapi.pro/api)
## 界面截图
![image](https://github.com/user-attachments/assets/a0dcd349-5df8-4dc8-9acf-ca272b239919)
![image](https://github.com/user-attachments/assets/c7d0f7e1-729c-43e2-ac7c-2cb73b0afc8e)
![image](https://github.com/user-attachments/assets/29f81de5-33fc-4fc5-a5ff-f9b54b653c7c)
![image](https://github.com/user-attachments/assets/4fa53e18-d2c5-477a-9b26-b86e44c71e35)
## 交流群
<img src="https://github.com/user-attachments/assets/9ca0bc82-e057-4230-a28d-9f198fa022e3" width="200">
- [聊天接口Chat](https://docs.newapi.pro/api/openai-chat)
- [图像接口Image](https://docs.newapi.pro/api/openai-image)
- [重排序接口Rerank](https://docs.newapi.pro/api/jinaai-rerank)
- [实时对话接口Realtime](https://docs.newapi.pro/api/openai-realtime)
- [Claude聊天接口messages](https://docs.newapi.pro/api/anthropic-chat)
## 相关项目
- [One API](https://github.com/songquanpeng/one-api):原版项目
- [Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)Midjourney接口支持
- [chatnio](https://github.com/Deeptrain-Community/chatnio):下一代 AI 一站式 B/C 端解决方案
- [chatnio](https://github.com/Deeptrain-Community/chatnio)下一代AI一站式B/C端解决方案
- [neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)用key查询使用额度
其他基于New API的项目
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版专注于高并发优化并支持Claude格式
- [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的前端美化版本,闭源免费
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon)New API高性能优化版
- [VoAPI](https://github.com/VoAPI/VoAPI)基于New API的前端美化版本
## 帮助支持
如有问题,请参考[帮助支持](https://docs.newapi.pro/support)
- [社区交流](https://docs.newapi.pro/support/community-interaction)
- [反馈问题](https://docs.newapi.pro/support/feedback-issues)
- [常见问题](https://docs.newapi.pro/support/faq)
## 🌟 Star History

View File

@@ -77,7 +77,6 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LinuxDOClientId = ""
var LinuxDOClientSecret = ""
@@ -235,6 +234,7 @@ const (
ChannelTypeMokaAI = 44
ChannelTypeVolcEngine = 45
ChannelTypeBaiduV2 = 46
ChannelTypeXinference = 47
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -287,4 +287,5 @@ var ChannelBaseURLs = []string{
"https://api.moka.ai", //44
"https://ark.cn-beijing.volces.com", //45
"https://qianfan.baidubce.com", //46
"", //47
}

View File

@@ -44,7 +44,7 @@ var fieldReplacer = strings.NewReplacer(
"\r", "\\r")
var dataReplacer = strings.NewReplacer(
"\n", "\ndata:",
"\n", "\n",
"\r", "\\r")
type CustomEvent struct {

14
common/json.go Normal file
View File

@@ -0,0 +1,14 @@
package common
import (
"bytes"
"encoding/json"
)
func DecodeJson(data []byte, v any) error {
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
}
func DecodeJsonStr(data string, v any) error {
return DecodeJson(StringToByteSlice(data), v)
}

3
constant/setup.go Normal file
View File

@@ -0,0 +1,3 @@
package constant
var Setup = false

View File

@@ -1,11 +1,12 @@
package constant
var (
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
)
var (

View File

@@ -105,9 +105,14 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
request := buildTestRequest(testModel)
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %v ", channel.Id, testModel, info))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
if err != nil {
return err, nil
}
adaptor.Init(info)
convertedRequest, err := adaptor.ConvertRequest(c, info, request)
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
if err != nil {
return err, nil
}
@@ -125,7 +130,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(httpResp)
err := service.RelayErrorHandler(httpResp, true)
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
}
}
@@ -143,10 +148,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
return err, nil
}
info.PromptTokens = usage.PromptTokens
priceData, err := helper.ModelPriceHelper(c, info, usage.PromptTokens, int(request.MaxTokens))
if err != nil {
return err, nil
}
quota := 0
if !priceData.UsePrice {
quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
@@ -187,7 +189,9 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
if strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") {
testRequest.MaxCompletionTokens = 10
} else if strings.Contains(model, "thinking") {
testRequest.MaxTokens = 50
if !strings.Contains(model, "claude") {
testRequest.MaxTokens = 50
}
} else {
testRequest.MaxTokens = 10
}

9
controller/image.go Normal file
View File

@@ -0,0 +1,9 @@
package controller
import (
"github.com/gin-gonic/gin"
)
func GetImage(c *gin.Context) {
}

View File

@@ -5,9 +5,11 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting"
"one-api/setting/operation_setting"
"one-api/setting/system_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -34,40 +36,44 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"version": common.Version,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"telegram_oauth": common.TelegramOAuthEnabled,
"telegram_bot_name": common.TelegramBotName,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_task": common.TaskEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
"demo_site_enabled": operation_setting.DemoSiteEnabled,
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
"setup": constant.Setup,
},
})
return

240
controller/oidc.go Normal file
View File

@@ -0,0 +1,240 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/system_setting"
"strconv"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
if oidcResponse.AccessToken == "" {
common.SysError("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysError("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysError("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -6,6 +6,7 @@ import (
"one-api/common"
"one-api/model"
"one-api/setting"
"one-api/setting/system_setting"
"strings"
"github.com/gin-gonic/gin"
@@ -51,6 +52,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "oidc.enabled":
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 OIDC 登录,请先填入 OIDC Client Id 以及 OIDC Client Secret",
})
return
}
case "LinuxDOOAuthEnabled":
if option.Value == "true" && common.LinuxDOClientId == "" {
c.JSON(http.StatusOK, gin.H{
@@ -81,6 +90,15 @@ func UpdateOption(c *gin.Context) {
"success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",
})
return
}
case "TelegramOAuthEnabled":
if option.Value == "true" && common.TelegramBotToken == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Telegram OAuth请先填入 Telegram Bot Token",
})
return
}
case "GroupRatio":
@@ -92,6 +110,7 @@ func UpdateOption(c *gin.Context) {
})
return
}
}
err = model.UpdateOption(option.Key, option.Value)
if err != nil {

View File

@@ -148,6 +148,50 @@ func WssRelay(c *gin.Context) {
}
}
func RelayClaude(c *gin.Context) {
//relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
originalModel := c.GetString("original_model")
var claudeErr *dto.ClaudeErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
claudeErr = claudeRequest(c, channel)
if claudeErr == nil {
return // 成功处理请求,直接返回
}
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if claudeErr != nil {
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
c.JSON(claudeErr.StatusCode, gin.H{
"type": "error",
"error": claudeErr.Error,
})
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
@@ -162,6 +206,13 @@ func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *mode
return relay.WssHelper(c, ws)
}
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relay.ClaudeHelper(c)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))

173
controller/setup.go Normal file
View File

@@ -0,0 +1,173 @@
package controller
import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting/operation_setting"
"time"
)
type Setup struct {
Status bool `json:"status"`
RootInit bool `json:"root_init"`
DatabaseType string `json:"database_type"`
}
type SetupRequest struct {
Username string `json:"username"`
Password string `json:"password"`
ConfirmPassword string `json:"confirmPassword"`
SelfUseModeEnabled bool `json:"SelfUseModeEnabled"`
DemoSiteEnabled bool `json:"DemoSiteEnabled"`
}
func GetSetup(c *gin.Context) {
setup := Setup{
Status: constant.Setup,
}
if constant.Setup {
c.JSON(200, gin.H{
"success": true,
"data": setup,
})
return
}
setup.RootInit = model.RootUserExists()
if common.UsingMySQL {
setup.DatabaseType = "mysql"
}
if common.UsingPostgreSQL {
setup.DatabaseType = "postgres"
}
if common.UsingSQLite {
setup.DatabaseType = "sqlite"
}
c.JSON(200, gin.H{
"success": true,
"data": setup,
})
}
func PostSetup(c *gin.Context) {
// Check if setup is already completed
if constant.Setup {
c.JSON(400, gin.H{
"success": false,
"message": "系统已经初始化完成",
})
return
}
// Check if root user already exists
rootExists := model.RootUserExists()
var req SetupRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(400, gin.H{
"success": false,
"message": "请求参数有误",
})
return
}
// If root doesn't exist, validate and create admin account
if !rootExists {
// Validate password
if req.Password != req.ConfirmPassword {
c.JSON(400, gin.H{
"success": false,
"message": "两次输入的密码不一致",
})
return
}
if len(req.Password) < 8 {
c.JSON(400, gin.H{
"success": false,
"message": "密码长度至少为8个字符",
})
return
}
// Create root user
hashedPassword, err := common.Password2Hash(req.Password)
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "系统错误: " + err.Error(),
})
return
}
rootUser := model.User{
Username: req.Username,
Password: hashedPassword,
Role: common.RoleRootUser,
Status: common.UserStatusEnabled,
DisplayName: "Root User",
AccessToken: nil,
Quota: 100000000,
}
err = model.DB.Create(&rootUser).Error
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "创建管理员账号失败: " + err.Error(),
})
return
}
}
// Set operation modes
operation_setting.SelfUseModeEnabled = req.SelfUseModeEnabled
operation_setting.DemoSiteEnabled = req.DemoSiteEnabled
// Save operation modes to database for persistence
err = model.UpdateOption("SelfUseModeEnabled", boolToString(req.SelfUseModeEnabled))
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "保存自用模式设置失败: " + err.Error(),
})
return
}
err = model.UpdateOption("DemoSiteEnabled", boolToString(req.DemoSiteEnabled))
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "保存演示站点模式设置失败: " + err.Error(),
})
return
}
// Update setup status
constant.Setup = true
setup := model.Setup{
Version: common.Version,
InitializedAt: time.Now().Unix(),
}
err = model.DB.Create(&setup).Error
if err != nil {
c.JSON(500, gin.H{
"success": false,
"message": "系统初始化失败: " + err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "系统初始化成功",
})
}
func boolToString(b bool) string {
if b {
return "true"
}
return "false"
}

View File

@@ -913,11 +913,12 @@ func TopUp(c *gin.Context) {
}
type UpdateUserSettingRequest struct {
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
QuotaWarningType string `json:"notify_type"`
QuotaWarningThreshold float64 `json:"quota_warning_threshold"`
WebhookUrl string `json:"webhook_url,omitempty"`
WebhookSecret string `json:"webhook_secret,omitempty"`
NotificationEmail string `json:"notification_email,omitempty"`
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
}
func UpdateUserSetting(c *gin.Context) {
@@ -993,6 +994,7 @@ func UpdateUserSetting(c *gin.Context) {
settings := map[string]interface{}{
constant.UserSettingNotifyType: req.QuotaWarningType,
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
}
// 如果是webhook类型,添加webhook相关设置

View File

@@ -11,7 +11,7 @@
- 类型为字符串,填写代理地址(例如 socks5 协议的代理地址)
3. thinking_to_content
- 用于标识是否将思考内容`reasoning_conetnt`转换为`<think>`标签拼接到内容中返回
- 用于标识是否将思考内容`reasoning_content`转换为`<think>`标签拼接到内容中返回
- 类型为布尔值,设置为 true 时启用思考内容转换
--------------------------------------------------------------
@@ -30,4 +30,4 @@
--------------------------------------------------------------
通过调整上述 JSON 配置中的值,可以灵活控制渠道的额外行为,比如是否进行格式化以及使用特定的网络代理。
通过调整上述 JSON 配置中的值,可以灵活控制渠道的额外行为,比如是否进行格式化以及使用特定的网络代理。

212
dto/claude.go Normal file
View File

@@ -0,0 +1,212 @@
package dto
import "encoding/json"
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Text *string `json:"text,omitempty"`
Model string `json:"model,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson *string `json:"partial_json,omitempty"`
Role string `json:"role,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
}
func (c *ClaudeMediaMessage) SetText(s string) {
c.Text = &s
}
func (c *ClaudeMediaMessage) GetText() string {
if c.Text == nil {
return ""
}
return *c.Text
}
func (c *ClaudeMediaMessage) IsStringContent() bool {
var content string
return json.Unmarshal(c.Content, &content) == nil
}
func (c *ClaudeMediaMessage) GetStringContent() string {
var content string
if err := json.Unmarshal(c.Content, &content); err == nil {
return content
}
return ""
}
func (c *ClaudeMediaMessage) SetContent(content any) {
jsonContent, _ := json.Marshal(content)
c.Content = jsonContent
}
func (c *ClaudeMediaMessage) ParseMediaContent() []ClaudeMediaMessage {
var mediaContent []ClaudeMediaMessage
if err := json.Unmarshal(c.Content, &mediaContent); err == nil {
return mediaContent
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data any `json:"data"`
}
type ClaudeMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
func (c *ClaudeMessage) IsStringContent() bool {
_, ok := c.Content.(string)
return ok
}
func (c *ClaudeMessage) GetStringContent() string {
if c.IsStringContent() {
return c.Content.(string)
}
return ""
}
func (c *ClaudeMessage) SetStringContent(content string) {
c.Content = content
}
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
// map content to []ClaudeMediaMessage
// parse to json
jsonContent, _ := json.Marshal(c.Content)
var contentList []ClaudeMediaMessage
err := json.Unmarshal(jsonContent, &contentList)
if err != nil {
return make([]ClaudeMediaMessage, 0), err
}
return contentList, nil
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema map[string]interface{} `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
System any `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
func (c *ClaudeRequest) IsStringSystem() bool {
_, ok := c.System.(string)
return ok
}
func (c *ClaudeRequest) GetStringSystem() string {
if c.IsStringSystem() {
return c.System.(string)
}
return ""
}
func (c *ClaudeRequest) SetStringSystem(system string) {
c.System = system
}
func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
// map content to []ClaudeMediaMessage
// parse to json
jsonContent, _ := json.Marshal(c.System)
var contentList []ClaudeMediaMessage
if err := json.Unmarshal(jsonContent, &contentList); err == nil {
return contentList
}
return make([]ClaudeMediaMessage, 0)
}
type ClaudeError struct {
Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
}
type ClaudeErrorWithStatusCode struct {
Error ClaudeError `json:"error"`
StatusCode int `json:"status_code"`
LocalError bool
}
type ClaudeResponse struct {
Id string `json:"id,omitempty"`
Type string `json:"type"`
Role string `json:"role,omitempty"`
Content []ClaudeMediaMessage `json:"content,omitempty"`
Completion string `json:"completion,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Model string `json:"model,omitempty"`
Error *ClaudeError `json:"error,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
Delta *ClaudeMediaMessage `json:"delta,omitempty"`
Message *ClaudeMediaMessage `json:"message,omitempty"`
}
// set index
func (c *ClaudeResponse) SetIndex(i int) {
c.Index = &i
}
// get index
func (c *ClaudeResponse) GetIndex() int {
if c.Index == nil {
return 0
}
return *c.Index
}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -113,9 +113,21 @@ type MediaContent struct {
InputAudio any `json:"input_audio,omitempty"`
}
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
if m.ImageUrl != nil {
return m.ImageUrl.(*MessageImageUrl)
}
return nil
}
type MessageImageUrl struct {
Url string `json:"url"`
Detail string `json:"detail"`
Url string `json:"url"`
Detail string `json:"detail"`
MimeType string
}
func (m *MessageImageUrl) IsRemoteImage() bool {
return strings.HasPrefix(m.Url, "http")
}
type MessageInputAudio struct {
@@ -244,43 +256,39 @@ func (m *Message) ParseContent() []MediaContent {
case ContentTypeImageURL:
imageUrl := contentItem["image_url"]
temp := &MessageImageUrl{
Detail: "high",
}
switch v := imageUrl.(type) {
case string:
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: v,
Detail: "high",
},
})
temp.Url = v
case map[string]interface{}:
url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string)
if !ok2 {
detail = "high"
if ok2 {
temp.Detail = detail
}
if ok1 {
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
Detail: detail,
},
})
temp.Url = url
}
}
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: temp,
})
case ContentTypeInputAudio:
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string)
if ok1 && ok2 {
temp := &MessageInputAudio{
Data: data,
Format: format,
}
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: data,
Format: format,
},
Type: ContentTypeInputAudio,
InputAudio: temp,
})
}
}

View File

@@ -1,20 +1,8 @@
package dto
type TextResponseWithError struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type SimpleResponse struct {
Usage `json:"usage"`
Error OpenAIError `json:"error"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
Error *OpenAIError `json:"error"`
}
type TextResponse struct {
@@ -38,6 +26,7 @@ type OpenAITextResponse struct {
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Error *OpenAIError `json:"error,omitempty"`
Usage `json:"usage"`
}
@@ -125,6 +114,20 @@ type ChatCompletionsStreamResponse struct {
Usage *Usage `json:"usage"`
}
func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
if len(c.Choices) == 0 {
return false
}
return len(c.Choices[0].Delta.ToolCalls) > 0
}
func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
if c.IsToolCall() {
return &c.Choices[0].Delta.ToolCalls[0]
}
return nil
}
func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
copy(choices, c.Choices)

View File

@@ -44,10 +44,11 @@ type RealtimeUsage struct {
}
type InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
CachedTokens int `json:"cached_tokens"`
CachedCreationTokens int
TextTokens int `json:"text_tokens"`
AudioTokens int `json:"audio_tokens"`
ImageTokens int `json:"image_tokens"`
}
type OutputTokenDetails struct {

View File

@@ -5,18 +5,29 @@ type RerankRequest struct {
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents,omitempty"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
OverLapTokens int `json:"overlap_tokens,omitempty"`
}
type RerankResponseDocument struct {
func (r *RerankRequest) GetReturnDocuments() bool {
if r.ReturnDocuments == nil {
return false
}
return *r.ReturnDocuments
}
type RerankResponseResult struct {
Document any `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type RerankResponse struct {
Results []RerankResponseDocument `json:"results"`
Usage Usage `json:"usage"`
type RerankDocument struct {
Text any `json:"text"`
}
type RerankResponse struct {
Results []RerankResponseResult `json:"results"`
Usage Usage `json:"usage"`
}

12
go.mod
View File

@@ -11,6 +11,7 @@ require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
github.com/bytedance/sonic v1.11.6
github.com/gin-contrib/cors v1.7.2
github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5
@@ -28,9 +29,9 @@ require (
github.com/samber/lo v1.39.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
golang.org/x/crypto v0.27.0
golang.org/x/crypto v0.35.0
golang.org/x/image v0.23.0
golang.org/x/net v0.28.0
golang.org/x/net v0.35.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -42,7 +43,6 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
@@ -84,9 +84,9 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.12.0 // indirect
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/sync v0.11.0 // indirect
golang.org/x/sys v0.30.0 // indirect
golang.org/x/text v0.22.0 // indirect
google.golang.org/protobuf v1.34.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect

20
go.sum
View File

@@ -217,18 +217,18 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg=
golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -239,14 +239,14 @@ golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=

View File

@@ -174,6 +174,14 @@ func TokenAuth() func(c *gin.Context) {
}
c.Request.Header.Set("Authorization", "Bearer "+key)
}
// 检查path包含/v1/messages
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
// 从x-api-key中获取key
key := c.Request.Header.Get("x-api-key")
if key != "" {
c.Request.Header.Set("Authorization", "Bearer "+key)
}
}
key := c.Request.Header.Get("Authorization")
parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ")

View File

@@ -212,6 +212,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
c.Set("channel_setting", channel.GetSetting())
c.Set("param_override", channel.GetParamOverride())
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}

View File

@@ -36,6 +36,7 @@ type Channel struct {
OtherInfo string `json:"other_info"`
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"`
ParamOverride *string `json:"param_override" gorm:"type:text"`
}
func (channel *Channel) GetModels() []string {
@@ -511,6 +512,17 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
channel.Setting = common.GetPointer[string](string(settingBytes))
}
func (channel *Channel) GetParamOverride() map[string]interface{} {
paramOverride := make(map[string]interface{})
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
err := json.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
if err != nil {
common.SysError("failed to unmarshal param override: " + err.Error())
}
}
return paramOverride
}
func GetChannelsByIds(ids []int) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("id in (?)", ids).Find(&channels).Error

View File

@@ -1,16 +1,18 @@
package model
import (
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"log"
"one-api/common"
"one-api/constant"
"os"
"strings"
"sync"
"time"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
var groupCol string
@@ -54,13 +56,40 @@ func createRootAccountIfNeed() error {
return nil
}
func checkSetup() {
setup := GetSetup()
if setup == nil {
// No setup record exists, check if we have a root user
if RootUserExists() {
common.SysLog("system is not initialized, but root user exists")
// Create setup record
newSetup := Setup{
Version: common.Version,
InitializedAt: time.Now().Unix(),
}
err := DB.Create(&newSetup).Error
if err != nil {
common.SysLog("failed to create setup record: " + err.Error())
}
constant.Setup = true
} else {
common.SysLog("system is not initialized and no root user exists")
constant.Setup = false
}
} else {
// Setup record exists, system is initialized
common.SysLog("system is already initialized at: " + time.Unix(setup.InitializedAt, 0).String())
constant.Setup = true
}
}
func chooseDB(envName string) (*gorm.DB, error) {
defer func() {
initCol()
}()
dsn := os.Getenv(envName)
if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") {
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
@@ -213,8 +242,10 @@ func migrateDB() error {
if err != nil {
return err
}
err = DB.AutoMigrate(&Setup{})
common.SysLog("database migrated")
err = createRootAccountIfNeed()
checkSetup()
//err = createRootAccountIfNeed()
return err
}

16
model/setup.go Normal file
View File

@@ -0,0 +1,16 @@
package model
type Setup struct {
ID uint `json:"id" gorm:"primaryKey"`
Version string `json:"version" gorm:"type:varchar(50);not null"`
InitializedAt int64 `json:"initialized_at" gorm:"type:bigint;not null"`
}
func GetSetup() *Setup {
var setup Setup
err := DB.First(&setup).Error
if err != nil {
return nil
}
return &setup
}

View File

@@ -9,7 +9,6 @@ import (
"strings"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -24,6 +23,7 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
@@ -442,6 +442,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -473,6 +481,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsTelegramIdAlreadyTaken(telegramId string) bool {
return DB.Unscoped().Where("telegram_id = ?", telegramId).Find(&User{}).RowsAffected == 1
}
@@ -796,3 +808,12 @@ func (user *User) FillUserByLinuxDOId() error {
err := DB.Where("linux_do_id = ?", user.LinuxDOId).First(user).Error
return err
}
func RootUserExists() bool {
var user User
err := DB.Where("role = ?", common.RoleRootUser).First(&user).Error
if err != nil {
return false
}
return true
}

View File

@@ -13,7 +13,7 @@ type Adaptor interface {
Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
@@ -22,6 +22,7 @@ type Adaptor interface {
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
}
type TaskAdaptor interface {

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
@@ -44,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -87,7 +93,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
}
return

View File

@@ -26,8 +26,8 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("/api/v1/tasks/%s", taskID)
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
var aliResponse AliResponse
@@ -36,7 +36,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
client := &http.Client{}
resp, err := client.Do(req)
@@ -58,7 +58,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliRes
return &response, nil, responseBody
}
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) {
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
waitSeconds := 3
step := 0
maxStep := 20
@@ -68,7 +68,7 @@ func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*Ali
for {
step++
rsp, err, body := updateTask(info, taskID, key)
rsp, err, body := updateTask(info, taskID)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
@@ -125,8 +125,6 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
@@ -148,7 +146,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey)
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
if err != nil {
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/gorilla/websocket"
"io"
"net/http"
common2 "one-api/common"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/service"
@@ -31,6 +32,9 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)

View File

@@ -20,6 +20,12 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model)
c.Set("converted_request", request)
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -43,12 +49,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
var claudeReq *claude.ClaudeRequest
var claudeReq *dto.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {

View File

@@ -13,4 +13,41 @@ var awsModelIDMap = map[string]string{
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-3-sonnet-20240229-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-opus-20240229-v1:0": {
"us": true,
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
"us": true,
"ap": true,
},
"anthropic.claude-3-5-haiku-20241022-v1:0": {
"us": true,
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"us": true,
},
}
var awsRegionCrossModelPrefixMap = map[string]string{
"us": "us",
"eu": "eu",
"ap": "apac",
}
var ChannelName = "aws"

View File

@@ -1,25 +1,25 @@
package aws
import (
"one-api/relay/channel/claude"
"one-api/dto"
)
type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
System string `json:"system,omitempty"`
Messages []claude.ClaudeMessage `json:"messages"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
AnthropicVersion string `json:"anthropic_version"`
System any `json:"system,omitempty"`
Messages []dto.ClaudeMessage `json:"messages"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest) *AwsClaudeRequest {
func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
return &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
System: req.System,

View File

@@ -1,21 +1,16 @@
package aws
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"io"
"net/http"
"one-api/common"
relaymodel "one-api/dto"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
@@ -39,15 +34,37 @@ func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.
return client, nil
}
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
return &relaymodel.OpenAIErrorWithStatusCode{
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
return &dto.OpenAIErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.OpenAIError{
Error: dto.OpenAIError{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
func awsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
if len(parts) > 0 {
regionPrefix = parts[0]
}
return regionPrefix
}
func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
return exists && regionSet[awsRegionPrefix]
}
func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
if !find {
return awsModelId
}
return modelPrefix + "." + awsModelId
}
func awsModelID(requestModel string) (string, error) {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID, nil
@@ -56,7 +73,7 @@ func awsModelID(requestModel string) (string, error) {
return requestModel, nil
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -67,6 +84,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
@@ -77,7 +100,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
@@ -89,25 +112,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
}
claudeResponse := new(claude.ClaudeResponse)
err = json.Unmarshal(awsResp.Body, claudeResponse)
if err != nil {
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
usage := relaymodel.Usage{
PromptTokens: claudeResponse.Usage.InputTokens,
CompletionTokens: claudeResponse.Usage.OutputTokens,
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
}
openaiResp.Usage = usage
c.JSON(http.StatusOK, openaiResp)
return nil, &usage
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
return nil, claudeInfo.Usage
}
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
@@ -118,6 +135,12 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
return wrapErr(errors.Wrap(err, "awsModelID")), nil
}
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
@@ -128,7 +151,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
if !ok {
return wrapErr(errors.New("request not found")), nil
}
claudeReq := claudeReq_.(*claude.ClaudeRequest)
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
@@ -143,79 +166,31 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
stream := awsResp.GetStream()
defer stream.Close()
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
var model string
isFirst := true
createdTime := common.GetTimestamp()
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
return false
}
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
for event := range stream.Events() {
switch v := event.(type) {
case *types.ResponseStreamMemberChunk:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
if respErr != nil {
return respErr, nil
}
claudeResp := new(claude.ClaudeResponse)
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return false
}
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
if claudeUsage != nil {
usage.PromptTokens += claudeUsage.InputTokens
usage.CompletionTokens += claudeUsage.OutputTokens
}
if response == nil {
return true
}
if response.Id != "" {
id = response.Id
}
if response.Model != "" {
model = response.Model
}
response.Created = createdTime
response.Id = id
response.Model = model
jsonStr, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case *types.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return false
return wrapErr(errors.New("unknown response type")), nil
default:
fmt.Println("union is nil or unknown type")
return false
}
})
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, createdTime, info.UpstreamModelName, usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
return wrapErr(errors.New("nil or unknown response type")), nil
}
}
helper.Done(c)
if resp != nil {
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
}
return nil, &usage
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
return nil, claudeInfo.Usage
}

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -104,7 +110,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -62,7 +68,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -22,6 +22,10 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -60,7 +64,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -1,94 +1,95 @@
package claude
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeMediaMessage struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
Thinking string `json:"thinking,omitempty"`
Signature string `json:"signature,omitempty"`
Delta string `json:"delta,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
}
type ClaudeMessageSource struct {
Type string `json:"type"`
MediaType string `json:"media_type"`
Data string `json:"data"`
}
type ClaudeMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema map[string]interface{} `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
System string `json:"system,omitempty"`
Messages []ClaudeMessage `json:"messages,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *Thinking `json:"thinking,omitempty"`
}
type Thinking struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Id string `json:"id"`
Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only
ContentBlock *ClaudeMediaMessage `json:"content_block"`
Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
//
//type ClaudeMetadata struct {
// UserId string `json:"user_id"`
//}
//
//type ClaudeMediaMessage struct {
// Type string `json:"type"`
// Text string `json:"text,omitempty"`
// Source *ClaudeMessageSource `json:"source,omitempty"`
// Usage *ClaudeUsage `json:"usage,omitempty"`
// StopReason *string `json:"stop_reason,omitempty"`
// PartialJson string `json:"partial_json,omitempty"`
// Thinking string `json:"thinking,omitempty"`
// Signature string `json:"signature,omitempty"`
// Delta string `json:"delta,omitempty"`
// // tool_calls
// Id string `json:"id,omitempty"`
// Name string `json:"name,omitempty"`
// Input any `json:"input,omitempty"`
// Content string `json:"content,omitempty"`
// ToolUseId string `json:"tool_use_id,omitempty"`
//}
//
//type ClaudeMessageSource struct {
// Type string `json:"type"`
// MediaType string `json:"media_type"`
// Data string `json:"data"`
//}
//
//type ClaudeMessage struct {
// Role string `json:"role"`
// Content any `json:"content"`
//}
//
//type Tool struct {
// Name string `json:"name"`
// Description string `json:"description,omitempty"`
// InputSchema map[string]interface{} `json:"input_schema"`
//}
//
//type InputSchema struct {
// Type string `json:"type"`
// Properties any `json:"properties,omitempty"`
// Required any `json:"required,omitempty"`
//}
//
//type ClaudeRequest struct {
// Model string `json:"model"`
// Prompt string `json:"prompt,omitempty"`
// System string `json:"system,omitempty"`
// Messages []ClaudeMessage `json:"messages,omitempty"`
// MaxTokens uint `json:"max_tokens,omitempty"`
// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
// StopSequences []string `json:"stop_sequences,omitempty"`
// Temperature *float64 `json:"temperature,omitempty"`
// TopP float64 `json:"top_p,omitempty"`
// TopK int `json:"top_k,omitempty"`
// //ClaudeMetadata `json:"metadata,omitempty"`
// Stream bool `json:"stream,omitempty"`
// Tools any `json:"tools,omitempty"`
// ToolChoice any `json:"tool_choice,omitempty"`
// Thinking *Thinking `json:"thinking,omitempty"`
//}
//
//type Thinking struct {
// Type string `json:"type"`
// BudgetTokens int `json:"budget_tokens"`
//}
//
//type ClaudeError struct {
// Type string `json:"type"`
// Message string `json:"message"`
//}
//
//type ClaudeResponse struct {
// Id string `json:"id"`
// Type string `json:"type"`
// Content []ClaudeMediaMessage `json:"content"`
// Completion string `json:"completion"`
// StopReason string `json:"stop_reason"`
// Model string `json:"model"`
// Error ClaudeError `json:"error"`
// Usage ClaudeUsage `json:"usage"`
// Index int `json:"index"` // stream only
// ContentBlock *ClaudeMediaMessage `json:"content_block"`
// Delta *ClaudeMediaMessage `json:"delta"` // stream only
// Message *ClaudeResponse `json:"message"` // stream only: message_start
//}
//
//type ClaudeUsage struct {
// InputTokens int `json:"input_tokens"`
// OutputTokens int `json:"output_tokens"`
//}

View File

@@ -29,9 +29,9 @@ func stopReasonClaude2OpenAI(reason string) string {
}
}
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
claudeRequest := ClaudeRequest{
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
StopSequences: nil,
@@ -60,17 +60,19 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
return &claudeRequest
}
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTool := Tool{
claudeTool := dto.Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
}
claudeTool.InputSchema = make(map[string]interface{})
claudeTool.InputSchema["type"] = params["type"].(string)
if params["type"] != nil {
claudeTool.InputSchema["type"] = params["type"].(string)
}
claudeTool.InputSchema["properties"] = params["properties"]
claudeTool.InputSchema["required"] = params["required"]
for s, a := range params {
@@ -83,7 +85,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
}
claudeRequest := ClaudeRequest{
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
StopSequences: nil,
@@ -107,7 +109,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &Thinking{
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
}
@@ -165,7 +167,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
lastMessage = fmtMessage
}
claudeMessages := make([]ClaudeMessage, 0)
claudeMessages := make([]dto.ClaudeMessage, 0)
isFirstMessage := true
for _, message := range formatMessages {
if message.Role == "system" {
@@ -186,63 +188,63 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
isFirstMessage = false
if message.Role != "user" {
// fix: first message is assistant, add user message
claudeMessage := ClaudeMessage{
claudeMessage := dto.ClaudeMessage{
Role: "user",
Content: []ClaudeMediaMessage{
Content: []dto.ClaudeMediaMessage{
{
Type: "text",
Text: "...",
Text: common.GetPointer[string]("..."),
},
},
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeMessage := ClaudeMessage{
claudeMessage := dto.ClaudeMessage{
Role: message.Role,
}
if message.Role == "tool" {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []ClaudeMediaMessage{
lastMessage.Content = []dto.ClaudeMediaMessage{
{
Type: "text",
Text: content,
Text: common.GetPointer[string](content),
},
}
}
lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
Content: message.Content,
})
claudeMessages[len(claudeMessages)-1] = lastMessage
continue
} else {
claudeMessage.Role = "user"
claudeMessage.Content = []ClaudeMediaMessage{
claudeMessage.Content = []dto.ClaudeMediaMessage{
{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.StringContent(),
Content: message.Content,
},
}
}
} else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]ClaudeMediaMessage, 0)
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := ClaudeMediaMessage{
claudeMediaMessage := dto.ClaudeMediaMessage{
Type: mediaMessage.Type,
}
if mediaMessage.Type == "text" {
claudeMediaMessage.Text = mediaMessage.Text
claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
} else {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
imageUrl := mediaMessage.GetImageMedia()
claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &ClaudeMessageSource{
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
Type: "base64",
}
// 判断是否是url
@@ -272,7 +274,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
Type: "tool_use",
Id: toolCall.ID,
Name: toolCall.Function.Name,
@@ -290,9 +292,8 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
return &claudeRequest, nil
}
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse
var claudeUsage *ClaudeUsage
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
@@ -308,7 +309,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
claudeUsage = &claudeResponse.Message.Usage
//claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
@@ -325,17 +326,17 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
})
}
} else {
return nil, nil
return nil
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
choice.Index = *claudeResponse.Index
choice.Delta.Content = claudeResponse.Delta.Text
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Function: dto.FunctionResponse{
Arguments: claudeResponse.Delta.PartialJson,
Arguments: *claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
@@ -352,26 +353,23 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if finishReason != "null" {
choice.FinishReason = &finishReason
}
claudeUsage = &claudeResponse.Usage
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil, nil
return nil
} else {
return nil, nil
return nil
}
}
if claudeUsage == nil {
claudeUsage = &ClaudeUsage{}
}
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
response.Choices = append(response.Choices, choice)
return &response, claudeUsage
return &response
}
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -379,8 +377,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
Created: common.GetTimestamp(),
}
var responseText string
var responseThinking string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
responseText = claudeResponse.Content[0].GetText()
responseThinking = claudeResponse.Content[0].Thinking
}
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
@@ -415,7 +415,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.Text
responseText = message.GetText()
}
}
}
@@ -427,6 +427,9 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choice.SetStringContent(responseText)
if len(responseThinking) > 0 {
choice.ReasoningContent = responseThinking
}
if len(tools) > 0 {
choice.Message.SetToolCalls(tools)
}
@@ -437,126 +440,228 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
return &fullTextResponse
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
var usage *dto.Usage
usage = &dto.Usage{}
responseText := ""
createdTime := common.GetTimestamp()
type ClaudeResponseInfo struct {
ResponseId string
Created int64
Model string
ResponseText strings.Builder
Usage *dto.Usage
}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
} else if claudeResponse.Type == "message_delta" {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
if claudeResponse.Usage.InputTokens > 0 {
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
}
}
if oaiResponse != nil {
oaiResponse.Id = claudeInfo.ResponseId
oaiResponse.Created = claudeInfo.Created
oaiResponse.Model = claudeInfo.Model
}
return true
}
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if response == nil {
return true
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err := common.DecodeJsonStr(data, &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Code: "stream_response_error",
Type: claudeResponse.Error.Type,
Message: claudeResponse.Error.Message,
},
StatusCode: http.StatusInternalServerError,
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
responseText += claudeResponse.Completion
responseId = response.Id
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
responseId = claudeResponse.Message.Id
info.UpstreamModelName = claudeResponse.Message.Model
usage.PromptTokens = claudeUsage.InputTokens
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
responseText += claudeResponse.Delta.Text
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
} else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
return true
} else {
return true
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
}
//response.Id = responseId
response.Id = responseId
response.Created = createdTime
response.Model = info.UpstreamModelName
helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
return nil
}
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
}
return true
})
if requestMode == RequestModeCompletion {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
if usage.PromptTokens == 0 {
usage.PromptTokens = info.PromptTokens
}
if usage.CompletionTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(responseId, createdTime, info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
helper.Done(c)
//resp.Body.Close()
return nil, usage
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 说明流模式建立失败,可能为官方出错
if claudeInfo.Usage.PromptTokens == 0 {
//usage.PromptTokens = info.PromptTokens
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if requestMode == RequestModeCompletion {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 {
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
helper.Done(c)
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse)
var err *dto.OpenAIErrorWithStatusCode
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
if err != nil {
return false
}
return true
})
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
return err, nil
}
if claudeResponse.Error.Type != "" {
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
return nil, claudeInfo.Usage
}
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
var claudeResponse dto.ClaudeResponse
err := common.DecodeJson(data, &claudeResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: dto.OpenAIError{
Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type,
Param: "",
Code: claudeResponse.Error.Type,
},
StatusCode: resp.StatusCode,
}, nil
StatusCode: http.StatusInternalServerError,
}
}
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
}
usage := dto.Usage{}
if requestMode == RequestModeCompletion {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = completionTokens
usage.TotalTokens = info.PromptTokens + completionTokens
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
}
claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
} else {
usage.PromptTokens = claudeResponse.Usage.InputTokens
usage.CompletionTokens = claudeResponse.Usage.OutputTokens
usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
var responseData []byte
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
case relaycommon.RelayFormatClaude:
responseData = data
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(responseData)
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body.Close()
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
if handleErr != nil {
return handleErr, nil
}
return nil, claudeInfo.Usage
}

View File

@@ -17,6 +17,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}
@@ -59,7 +65,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = cohereRerankHandler(c, resp, info)

View File

@@ -1,6 +1,7 @@
package cohere
var ModelList = []string{
"command-a-03-2025",
"command-r", "command-r-plus",
"command-r-08-2024", "command-r-plus-08-2024",
"c4ai-aya-23-35b", "c4ai-aya-23-8b",

View File

@@ -40,8 +40,8 @@ type CohereRerankRequest struct {
}
type CohereRerankResponseResult struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta CohereMeta `json:"meta"`
Results []dto.RerankResponseResult `json:"results"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -44,7 +50,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -68,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -9,7 +9,6 @@ import (
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"strings"
)
const (
@@ -23,6 +22,12 @@ type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -34,15 +39,16 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "agent") {
a.BotType = BotTypeAgent
} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
a.BotType = BotTypeWorkFlow
} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
a.BotType = BotTypeCompletion
} else {
a.BotType = BotTypeChatFlow
}
//if strings.HasPrefix(info.UpstreamModelName, "agent") {
// a.BotType = BotTypeAgent
//} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
// a.BotType = BotTypeWorkFlow
//} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
// a.BotType = BotTypeCompletion
//} else {
//}
a.BotType = BotTypeChatFlow
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -64,11 +70,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Dify(*request), nil
return requestOpenAI2Dify(c, info, *request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

View File

@@ -8,6 +8,14 @@ type DifyChatRequest struct {
ResponseMode string `json:"response_mode"`
User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"`
Files []DifyFile `json:"files"`
}
type DifyFile struct {
Type string `json:"type"`
TransferMode string `json:"transfer_mode"`
URL string `json:"url,omitempty"`
UploadFileId string `json:"upload_file_id,omitempty"`
}
type DifyMetaData struct {
@@ -17,6 +25,8 @@ type DifyMetaData struct {
type DifyData struct {
WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"`
NodeType string `json:"node_type"`
Status string `json:"status"`
}
type DifyChatCompletionResponse struct {

View File

@@ -2,9 +2,12 @@ package dify
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/json"
"github.com/gin-gonic/gin"
"fmt"
"io"
"mime/multipart"
"net/http"
"one-api/common"
"one-api/constant"
@@ -12,35 +15,163 @@ import (
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"os"
"strings"
"github.com/gin-gonic/gin"
)
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest {
content := ""
for _, message := range request.Messages {
if message.Role == "system" {
content += "SYSTEM: \n" + message.StringContent() + "\n"
} else if message.Role == "assistant" {
content += "ASSISTANT: \n" + message.StringContent() + "\n"
} else {
content += "USER: \n" + message.StringContent() + "\n"
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
switch media.Type {
case dto.ContentTypeImageURL:
// Decode base64 data
imageMedia := media.GetImageMedia()
base64Data := imageMedia.Url
// Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
if idx := strings.Index(base64Data, ","); idx != -1 {
base64Data = base64Data[idx+1:]
}
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
common.SysError("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
common.SysError("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
defer os.Remove(tempFile.Name())
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
common.SysError("failed to write to temp file: " + err.Error())
return nil
}
// Create multipart form
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// Add user field
if err := writer.WriteField("user", user); err != nil {
common.SysError("failed to add user field: " + err.Error())
return nil
}
// Create form file with proper mime type
mimeType := imageMedia.MimeType
if mimeType == "" {
mimeType = "image/jpeg" // default mime type
}
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
common.SysError("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
common.SysError("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
common.SysError("failed to create request: " + err.Error())
return nil
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
// Send request
client := service.GetImpatientHttpClient()
resp, err := client.Do(req)
if err != nil {
common.SysError("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
// Parse response
var result struct {
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
common.SysError("failed to decode response: " + err.Error())
return nil
}
return &DifyFile{
UploadFileId: result.Id,
Type: "image",
TransferMode: "local_file",
}
}
return nil
}
func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
difyReq := DifyChatRequest{
Inputs: make(map[string]interface{}),
AutoGenerateName: false,
}
user := request.User
if user == "" {
user = helper.GetResponseID(c)
}
difyReq.User = user
files := make([]DifyFile, 0)
var content strings.Builder
for _, message := range request.Messages {
if message.Role == "system" {
content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
} else if message.Role == "assistant" {
content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
} else {
parseContent := message.ParseContent()
for _, mediaContent := range parseContent {
switch mediaContent.Type {
case dto.ContentTypeText:
content.WriteString("USER: \n" + mediaContent.Text + "\n")
case dto.ContentTypeImageURL:
media := mediaContent.GetImageMedia()
var file *DifyFile
if media.IsRemoteImage() {
file.Type = media.MimeType
file.TransferMode = "remote_url"
file.URL = media.Url
} else {
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
}
if file != nil {
files = append(files, *file)
}
}
}
}
}
difyReq.Query = content.String()
difyReq.Files = files
mode := "blocking"
if request.Stream {
mode = "streaming"
}
user := request.User
if user == "" {
user = "api-user"
}
return &DifyChatRequest{
Inputs: make(map[string]interface{}),
Query: content,
ResponseMode: mode,
User: user,
AutoGenerateName: false,
}
difyReq.ResponseMode = mode
return &difyReq
}
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
@@ -50,11 +181,29 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
Model: "dify",
}
var choice dto.ChatCompletionsStreamResponseChoice
if constant.DifyDebug && difyResponse.Event == "workflow_started" {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
if strings.HasPrefix(difyResponse.Event, "workflow_") {
if constant.DifyDebug {
text := "Workflow: " + difyResponse.Data.WorkflowId
if difyResponse.Event == "workflow_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if strings.HasPrefix(difyResponse.Event, "node_") {
if constant.DifyDebug {
text := "Node: " + difyResponse.Data.NodeType
if difyResponse.Event == "node_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
if difyResponse.Answer == "<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>\n" {
difyResponse.Answer = "<think>"
} else if difyResponse.Answer == "</details>" {
difyResponse.Answer = "</think>"
}
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)
@@ -66,38 +215,38 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
var nodeToken int
helper.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
return true
}
var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage
break
return false
} else if difyResponse.Event == "error" {
break
return false
} else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
nodeToken += 1
}
}
}
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
}
return true
})
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
@@ -112,6 +261,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
}
usage.CompletionTokens += nodeToken
return nil, usage
}

View File

@@ -21,6 +21,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -70,6 +76,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
}
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
@@ -83,7 +95,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -99,8 +111,37 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
if request.Input == nil {
return nil, errors.New("input is required")
}
inputs := request.ParseInput()
if len(inputs) == 0 {
return nil, errors.New("input is empty")
}
// only process the first input
geminiRequest := GeminiEmbeddingRequest{
Content: GeminiChatContent{
Parts: []GeminiPart{
{
Text: inputs[0],
},
},
},
}
// set specific parameters for different models
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
switch info.UpstreamModelName {
case "text-embedding-004":
// except embedding-001 supports setting `OutputDimensionality`
if request.Dimensions > 0 {
geminiRequest.OutputDimensionality = request.Dimensions
}
}
return geminiRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
@@ -112,6 +153,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return GeminiImageHandler(c, resp, info)
}
// check if the model is an embedding model
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return GeminiEmbeddingHandler(c, resp, info)
}
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {

View File

@@ -16,8 +16,14 @@ var ModelList = []string{
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-pro-preview-03-25",
// imagen models
"imagen-3.0-generate-002",
// embedding models
"gemini-embedding-exp-03-07",
"text-embedding-004",
"embedding-001",
}
var SafetySettingList = []string{

View File

@@ -136,3 +136,19 @@ type GeminiImagePrediction struct {
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}
// Embedding related structs
type GeminiEmbeddingRequest struct {
Content GeminiChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
type GeminiEmbeddingResponse struct {
Embedding ContentEmbedding `json:"embedding"`
}
type ContentEmbedding struct {
Values []float64 `json:"values"`
}

View File

@@ -180,9 +180,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
// 是url获取图片的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
@@ -193,7 +193,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
},
})
} else {
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
@@ -580,3 +580,52 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiEmbeddingResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
// convert to openai format response
openAIResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: []dto.OpenAIEmbeddingResponseItem{
{
Object: "embedding",
Embedding: geminiResponse.Embedding.Values,
Index: 0,
},
},
Model: info.UpstreamModelName,
}
// calculate usage
// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
// Google has not yet clarified how embedding models will be billed
// refer to openai billing method to use input tokens billing
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
usage = &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
openAIResponse.Usage = *usage.(*dto.Usage)
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return usage, nil
}

View File

@@ -8,13 +8,21 @@ import (
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -43,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
@@ -61,9 +69,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.RelayMode == constant.RelayModeRerank {
err, usage = JinaRerankHandler(c, resp)
err, usage = common_handler.RerankHandler(c, info, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
err, usage = jinaEmbeddingHandler(c, resp)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -1,60 +1 @@
package jina
import (
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/service"
)
func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.RerankResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}
func jinaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var jinaResp dto.OpenAIEmbeddingResponse
err = json.Unmarshal(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jsonResponse, err := json.Marshal(jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &jinaResp.Usage
}

View File

@@ -14,6 +14,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -37,7 +43,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -61,7 +67,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -10,7 +10,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
imageUrl := mediaMessage.GetImageMedia()
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -51,7 +57,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -73,13 +79,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = mokaEmbeddingHandler(c, resp)
default:
// err, usage = mokaHandler(c, resp)
}
return
}

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -43,7 +49,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -69,7 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayMode == relayconstant.RelayModeEmbeddings {
err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
}
return

View File

@@ -19,7 +19,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, err
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
imageUrl := mediaMessage.GetImageMedia()
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"mime/multipart"
"net/http"
@@ -14,13 +13,18 @@ import (
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/jina"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
@@ -28,11 +32,39 @@ type Adaptor struct {
ResponseFormat string
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if !strings.Contains(request.Model, "claude") {
return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
}
aiRequest, err := service.ClaudeToOpenAIRequest(*request)
if err != nil {
return nil, err
}
if info.SupportStreamOptions {
aiRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
return a.ConvertOpenAIRequest(c, info, aiRequest)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
HasSentThinkingContent: false,
}
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == constant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
@@ -101,14 +133,14 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
}
//if info.ChannelType == common.ChannelTypeOpenRouter {
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
// req.Header.Set("X-Title", "One API")
//}
if info.ChannelType == common.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
header.Set("X-Title", "New API")
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -230,12 +262,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case constant.RelayModeImagesGenerations:
err, usage = OpenaiTTSHandler(c, resp, info)
case constant.RelayModeRerank:
err, usage = jina.JinaRerankHandler(c, resp)
err, usage = common_handler.RerankHandler(c, info, resp)
default:
if info.IsStream {
err, usage = OaiStreamHandler(c, resp, info)
} else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = OpenaiHandler(c, resp, info)
}
}
return
@@ -251,6 +283,10 @@ func (a *Adaptor) GetModelList() []string {
return lingyiwanwu.ModelList
case common.ChannelTypeMiniMax:
return minimax.ModelList
case common.ChannelTypeXinference:
return xinference.ModelList
case common.ChannelTypeOpenRouter:
return openrouter.ModelList
default:
return ModelList
}
@@ -266,6 +302,10 @@ func (a *Adaptor) GetChannelName() string {
return lingyiwanwu.ChannelName
case common.ChannelTypeMiniMax:
return minimax.ChannelName
case common.ChannelTypeXinference:
return xinference.ChannelName
case common.ChannelTypeOpenRouter:
return openrouter.ChannelName
default:
return ChannelName
}

View File

@@ -0,0 +1,188 @@
package openai
import (
"encoding/json"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// 辅助函数
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
}
return nil
}
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
return nil
}
func processStreamResponse(item string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
return nil
}
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
if err := processStreamResponse(item, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
systemFingerprint *string, model *string, usage **dto.Usage,
containStreamUsage *bool, info *relaycommon.RelayInfo,
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}
*responseId = lastStreamResponse.Id
*createAt = lastStreamResponse.Created
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
*model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
*containStreamUsage = true
*usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
*shouldSendLastResp = false
}
}
return nil
}
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
case relaycommon.RelayFormatClaude:
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
if !containStreamUsage {
streamResponse.Usage = usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
}
}

View File

@@ -12,7 +12,6 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"os"
@@ -34,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
}
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &lastStreamResponse); err != nil {
if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
return err
}
@@ -66,6 +65,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
response.Choices[i].Delta.Reasoning = nil
}
info.ThinkingContentInfo.IsFirstThinkingContent = false
info.ThinkingContentInfo.HasSentThinkingContent = true
return helper.ObjectData(c, response)
}
}
@@ -77,7 +77,8 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent {
// only send `</think>` tag when previous thinking content has been sent
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>\n")
@@ -88,7 +89,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
helper.ObjectData(c, response)
}
// Convert reasoning content to regular content
// Convert reasoning content to regular content if any
if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
@@ -137,10 +138,11 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
err := sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.LogError(c, "streaming error: "+err.Error())
common.SysError("error handling stream format: " + err.Error())
}
info.SetFirstResponseTime()
}
lastStreamData = data
streamItems = append(streamItems, data)
@@ -149,7 +151,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
shouldSendLastResp := true
var lastStreamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
if err == nil {
responseId = lastStreamResponse.Id
createAt = lastStreamResponse.Created
@@ -172,83 +174,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 计算token
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
var streamResponses []dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
// handle both reasoning_content and reasoning
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
} else {
for _, streamResponse := range streamResponses {
//if service.ValidUsage(streamResponse.Usage) {
// usage = streamResponse.Usage
// containStreamUsage = true
//}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent()) // This will handle both reasoning_content and reasoning
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > toolCount {
toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
}
case relayconstant.RelayModeCompletions:
var streamResponses []dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
if err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
if err == nil {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
} else {
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.SysError("error processing tokens: " + err.Error())
}
if !containStreamUsage {
@@ -262,20 +190,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
}
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
helper.Done(c)
//resp.Body.Close()
return nil, usage
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var simpleResponse dto.SimpleResponse
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -284,16 +205,29 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &simpleResponse)
err = common.DecodeJson(responseBody, &simpleResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if simpleResponse.Error.Type != "" {
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: simpleResponse.Error,
Error: *simpleResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
break
case relaycommon.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := json.Marshal(claudeResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
responseBody = claudeRespStr
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
@@ -306,19 +240,20 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
common.SysError("error copying response body: " + err.Error())
}
resp.Body.Close()
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, model)
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
PromptTokens: promptTokens,
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
}
return nil, &simpleResponse.Usage

View File

@@ -1,74 +0,0 @@
package openrouter
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
req.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
req.Set("X-Title", "New API")
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -54,7 +60,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -57,7 +63,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -66,7 +71,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -16,6 +16,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -48,7 +54,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
@@ -72,16 +78,16 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
case constant.RelayModeCompletions:
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -12,6 +12,6 @@ type SFMeta struct {
}
type SFRerankResponse struct {
Results []dto.RerankResponseDocument `json:"results"`
Meta SFMeta `json:"meta"`
Results []dto.RerankResponseResult `json:"results"`
Meta SFMeta `json:"meta"`
}

View File

@@ -23,6 +23,12 @@ type Adaptor struct {
Timestamp int64
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -52,7 +58,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -78,7 +84,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -38,6 +38,16 @@ type Adaptor struct {
AccountCredentials Credentials
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
} else {
c.Set("request_model", request.Model)
}
vertexClaudeReq := copyRequest(request, anthropicVersion)
return vertexClaudeReq, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -119,7 +129,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -175,7 +185,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case RequestModeGemini:
err, usage = gemini.GeminiChatHandler(c, resp, info)
case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
}
return

View File

@@ -1,25 +1,25 @@
package vertex
import (
"one-api/relay/channel/claude"
"one-api/dto"
)
type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []claude.ClaudeMessage `json:"messages"`
System any `json:"system,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *claude.Thinking `json:"thinking,omitempty"`
AnthropicVersion string `json:"anthropic_version"`
Messages []dto.ClaudeMessage `json:"messages"`
System any `json:"system,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *claude.ClaudeRequest, version string) *VertexAIClaudeRequest {
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
return &VertexAIClaudeRequest{
AnthropicVersion: version,
System: req.System,

View File

@@ -17,6 +17,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -50,7 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -75,10 +81,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
case constant.RelayModeEmbeddings:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -0,0 +1,8 @@
package xinference
var ModelList = []string{
"bge-reranker-v2-m3",
"jina-reranker-v2",
}
var ChannelName = "xinference"

View File

@@ -0,0 +1,11 @@
package xinference
type XinRerankResponseDocument struct {
Document string `json:"document,omitempty"`
Index int `json:"index"`
RelevanceScore float64 `json:"relevance_score"`
}
type XinRerankResponse struct {
Results []XinRerankResponseDocument `json:"results"`
}

View File

@@ -16,6 +16,12 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -38,7 +44,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -55,7 +61,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}

View File

@@ -14,6 +14,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -42,7 +48,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -61,7 +67,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}

View File

@@ -15,6 +15,12 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
@@ -39,7 +45,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
@@ -58,7 +64,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
@@ -67,7 +72,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = openai.OaiStreamHandler(c, resp, info)
} else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage = openai.OpenaiHandler(c, resp, info)
}
return
}

View File

@@ -79,7 +79,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
imageUrl := mediaMessage.GetImageMedia()
// check if base64
if strings.HasPrefix(imageUrl.Url, "data:image/") {
// 去除base64数据的URL前缀如果有

163
relay/claude_handler.go Normal file
View File

@@ -0,0 +1,163 @@
package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"strings"
)
func getAndValidateClaudeRequest(c *gin.Context) (textRequest *dto.ClaudeRequest, err error) {
textRequest = &dto.ClaudeRequest{}
err = c.ShouldBindJSON(textRequest)
if err != nil {
return nil, err
}
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return nil, errors.New("field messages is required")
}
if textRequest.Model == "" {
return nil, errors.New("field model is required")
}
return textRequest, nil
}
func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfoClaude(c)
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateClaudeRequest(c)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "invalid_claude_request", http.StatusBadRequest)
}
if textRequest.Stream {
relayInfo.IsStream = true
}
err = helper.ModelMappedHelper(c, relayInfo)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
}
textRequest.Model = relayInfo.UpstreamModelName
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "count_token_messages_failed", http.StatusInternalServerError)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return service.OpenAIErrorToClaudeError(openaiErr)
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return service.ClaudeErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if textRequest.MaxTokens == 0 {
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
if textRequest.Thinking == nil {
// 因为BudgetTokens 必须大于1024
if textRequest.MaxTokens < 1280 {
textRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
textRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
textRequest.TopP = 0
textRequest.Temperature = common.GetPointer[float64](1.0)
}
textRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
relayInfo.UpstreamModelName = textRequest.Model
}
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
}
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return service.OpenAIErrorToClaudeError(openaiErr)
}
}
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
//log.Printf("usage: %v", usage)
if openaiErr != nil {
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return service.OpenAIErrorToClaudeError(openaiErr)
}
service.PostClaudeConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}
func getClaudePromptTokens(textRequest *dto.ClaudeRequest, info *relaycommon.RelayInfo) (int, error) {
var promptTokens int
var err error
switch info.RelayMode {
default:
promptTokens, err = service.CountTokenClaudeRequest(*textRequest, info.UpstreamModelName)
}
info.PromptTokens = promptTokens
return promptTokens, err
}

View File

@@ -15,6 +15,27 @@ import (
type ThinkingContentInfo struct {
IsFirstThinkingContent bool
SendLastThinkingContent bool
HasSentThinkingContent bool
}
const (
LastMessageTypeText = "text"
LastMessageTypeTools = "tools"
)
type ClaudeConvertInfo struct {
LastMessagesType string
Index int
}
const (
RelayFormatOpenAI = "openai"
RelayFormatClaude = "claude"
)
type RerankerInfo struct {
Documents []any
ReturnDocuments bool
}
type RelayInfo struct {
@@ -55,10 +76,15 @@ type RelayInfo struct {
AudioUsage bool
ReasoningEffort string
ChannelSetting map[string]interface{}
ParamOverride map[string]interface{}
UserSetting map[string]interface{}
UserEmail string
UserQuota int
RelayFormat string
SendResponseCount int
ThinkingContentInfo
ClaudeConvertInfo
*RerankerInfo
}
// 定义支持流式选项的通道类型
@@ -82,10 +108,31 @@ func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
return info
}
func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
info := GenRelayInfo(c)
info.RelayFormat = RelayFormatClaude
info.ShouldIncludeUsage = false
info.ClaudeConvertInfo = ClaudeConvertInfo{
LastMessagesType: LastMessageTypeText,
}
return info
}
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
info := GenRelayInfo(c)
info.RelayMode = relayconstant.RelayModeRerank
info.RerankerInfo = &RerankerInfo{
Documents: req.Documents,
ReturnDocuments: req.GetReturnDocuments(),
}
return info
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
channelSetting := c.GetStringMap("channel_setting")
paramOverride := c.GetStringMap("param_override")
tokenId := c.GetInt("token_id")
tokenKey := c.GetString("token_key")
@@ -123,6 +170,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Organization: c.GetString("channel_organization"),
ChannelSetting: channelSetting,
ParamOverride: paramOverride,
RelayFormat: RelayFormatOpenAI,
ThinkingContentInfo: ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,

View File

@@ -0,0 +1,68 @@
package common_handler
import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/service"
)
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if common.DebugEnabled {
println("reranker response body: ", string(responseBody))
}
var jinaResp dto.RerankResponse
if info.ChannelType == common.ChannelTypeXinference {
var xinRerankResponse xinference.XinRerankResponse
err = common.DecodeJson(responseBody, &xinRerankResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
respResult := dto.RerankResponseResult{
Index: result.Index,
RelevanceScore: result.RelevanceScore,
}
if info.ReturnDocuments {
var document any
if result.Document == "" {
document = info.Documents[result.Index]
} else {
document = result.Document
}
respResult.Document = document
}
jinaRespResults[i] = respResult
}
jinaResp = dto.RerankResponse{
Results: jinaRespResults,
Usage: dto.Usage{
PromptTokens: info.PromptTokens,
TotalTokens: info.PromptTokens,
},
}
} else {
err = common.DecodeJson(responseBody, &jinaResp)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}
c.Writer.Header().Set("Content-Type", "application/json")
c.JSON(http.StatusOK, jinaResp)
return nil, &jinaResp.Usage
}

View File

@@ -31,6 +31,7 @@ const (
APITypeVolcEngine
APITypeBaiduV2
APITypeOpenRouter
APITypeXinference
APITypeDummy // this one is only for count, do not add any channel after this
)
@@ -89,6 +90,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
apiType = APITypeBaiduV2
case common.ChannelTypeOpenRouter:
apiType = APITypeOpenRouter
case common.ChannelTypeXinference:
apiType = APITypeXinference
}
if apiType == -1 {
return APITypeOpenAI, false

View File

@@ -19,6 +19,30 @@ func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("X-Accel-Buffering", "no")
}
func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
jsonData, err := json.Marshal(resp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
} else {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
}
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
func StringData(c *gin.Context, str string) error {
//str = strings.TrimPrefix(str, "data: ")
//str = strings.TrimSuffix(str, "\r")

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"one-api/common"
constant2 "one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting"
"one-api/setting/operation_setting"
@@ -16,9 +17,14 @@ type PriceData struct {
CacheRatio float64
GroupRatio float64
UsePrice bool
CacheCreationRatio float64
ShouldPreConsumedQuota int
}
func (p PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota)
}
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
groupRatio := setting.GetGroupRatio(info.Group)
@@ -26,6 +32,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var modelRatio float64
var completionRatio float64
var cacheRatio float64
var cacheCreationRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
if maxTokens != 0 {
@@ -34,26 +41,44 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var success bool
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
if !success {
if info.UserId == 1 {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} else {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
acceptUnsetRatio := false
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
b, ok := accept.(bool)
if ok {
acceptUnsetRatio = b
}
}
if !acceptUnsetRatio {
if info.UserId == 1 {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置请设置或开始自用模式Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
} else {
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置, 请联系管理员设置Model %s ratio or price not set, please contact administrator to set", info.OriginModelName, info.OriginModelName)
}
}
}
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
ratio := modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
return PriceData{
priceData := PriceData{
ModelPrice: modelPrice,
ModelRatio: modelRatio,
CompletionRatio: completionRatio,
GroupRatio: groupRatio,
UsePrice: usePrice,
CacheRatio: cacheRatio,
CacheCreationRatio: cacheCreationRatio,
ShouldPreConsumedQuota: preConsumedQuota,
}, nil
}
if common.DebugEnabled {
println(fmt.Sprintf("model_price_helper result: %s", priceData.ToSetting()))
}
return priceData, nil
}

View File

@@ -117,7 +117,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr

View File

@@ -155,7 +155,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr := service.RelayErrorHandler(httpResp)
openaiErr := service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr

View File

@@ -17,6 +17,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"strings"
"time"
@@ -108,7 +109,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
@@ -152,38 +153,57 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
adaptor.Init(relayInfo)
var requestBody io.Reader
//if relayInfo.ChannelType == common.ChannelTypeOpenAI && !isModelMapped {
// body, err := common.GetRequestBody(c)
// if err != nil {
// return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
// }
// requestBody = bytes.NewBuffer(body)
//} else {
//
//}
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_request_body_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonData)
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_unmarshal_failed", http.StatusInternalServerError)
}
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "param_override_marshal_failed", http.StatusInternalServerError)
}
}
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
requestBody = bytes.NewBuffer(jsonData)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
if resp != nil {
httpResp = resp.(*http.Response)
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr
@@ -369,17 +389,18 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := service.PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
logModel := modelName
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"

View File

@@ -18,7 +18,6 @@ import (
"one-api/relay/channel/mokaai"
"one-api/relay/channel/ollama"
"one-api/relay/channel/openai"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/siliconflow"
@@ -34,8 +33,6 @@ import (
func GetAdaptor(apiType int) channel.Adaptor {
switch apiType {
//case constant.APITypeAIProxyLibrary:
// return &aiproxy.Adaptor{}
case constant.APITypeAli:
return &ali.Adaptor{}
case constant.APITypeAnthropic:
@@ -85,7 +82,9 @@ func GetAdaptor(apiType int) channel.Adaptor {
case constant.APITypeBaiduV2:
return &baidu_v2.Adaptor{}
case constant.APITypeOpenRouter:
return &openrouter.Adaptor{}
return &openai.Adaptor{}
case constant.APITypeXinference:
return &openai.Adaptor{}
}
return nil
}

View File

@@ -98,7 +98,7 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr

View File

@@ -25,7 +25,6 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
}
func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
var rerankRequest *dto.RerankRequest
err := common.UnmarshalBodyReusable(c, &rerankRequest)
@@ -33,6 +32,9 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
}
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
if rerankRequest.Query == "" {
return service.OpenAIErrorWrapperLocal(fmt.Errorf("query is empty"), "invalid_query", http.StatusBadRequest)
}
@@ -90,7 +92,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
openaiErr = service.RelayErrorHandler(httpResp)
openaiErr = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
return openaiErr

View File

@@ -13,6 +13,8 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.GlobalAPIRateLimit())
{
apiRouter.GET("/setup", controller.GetSetup)
apiRouter.POST("/setup", controller.PostSetup)
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
@@ -25,6 +27,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)

View File

@@ -35,6 +35,7 @@ func SetRelayRouter(router *gin.Engine) {
//http router
httpRouter := relayV1Router.Group("")
httpRouter.Use(middleware.Distribute())
httpRouter.POST("/messages", controller.RelayClaude)
httpRouter.POST("/completions", controller.Relay)
httpRouter.POST("/chat/completions", controller.Relay)
httpRouter.POST("/edits", controller.Relay)

351
service/convert.go Normal file
View File

@@ -0,0 +1,351 @@
package service
import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
)
func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIRequest, error) {
openAIRequest := dto.GeneralOpenAIRequest{
Model: claudeRequest.Model,
MaxTokens: claudeRequest.MaxTokens,
Temperature: claudeRequest.Temperature,
TopP: claudeRequest.TopP,
Stream: claudeRequest.Stream,
}
// Convert stop sequences
if len(claudeRequest.StopSequences) == 1 {
openAIRequest.Stop = claudeRequest.StopSequences[0]
} else if len(claudeRequest.StopSequences) > 1 {
openAIRequest.Stop = claudeRequest.StopSequences
}
// Convert tools
tools, _ := common.Any2Type[[]dto.Tool](claudeRequest.Tools)
openAITools := make([]dto.ToolCallRequest, 0)
for _, claudeTool := range tools {
openAITool := dto.ToolCallRequest{
Type: "function",
Function: dto.FunctionRequest{
Name: claudeTool.Name,
Description: claudeTool.Description,
Parameters: claudeTool.InputSchema,
},
}
openAITools = append(openAITools, openAITool)
}
openAIRequest.Tools = openAITools
// Convert messages
openAIMessages := make([]dto.Message, 0)
// Add system message if present
if claudeRequest.System != nil {
if claudeRequest.IsStringSystem() {
openAIMessage := dto.Message{
Role: "system",
}
openAIMessage.SetStringContent(claudeRequest.GetStringSystem())
openAIMessages = append(openAIMessages, openAIMessage)
} else {
systems := claudeRequest.ParseSystem()
if len(systems) > 0 {
systemStr := ""
openAIMessage := dto.Message{
Role: "system",
}
for _, system := range systems {
systemStr += system.Type
}
openAIMessage.SetStringContent(systemStr)
openAIMessages = append(openAIMessages, openAIMessage)
}
}
}
for _, claudeMessage := range claudeRequest.Messages {
openAIMessage := dto.Message{
Role: claudeMessage.Role,
}
//log.Printf("claudeMessage.Content: %v", claudeMessage.Content)
if claudeMessage.IsStringContent() {
openAIMessage.SetStringContent(claudeMessage.GetStringContent())
} else {
content, err := claudeMessage.ParseContent()
if err != nil {
return nil, err
}
contents := content
var toolCalls []dto.ToolCallRequest
mediaMessages := make([]dto.MediaContent, 0, len(contents))
for _, mediaMsg := range contents {
switch mediaMsg.Type {
case "text":
message := dto.MediaContent{
Type: "text",
Text: mediaMsg.GetText(),
}
mediaMessages = append(mediaMessages, message)
case "image":
// Handle image conversion (base64 to URL or keep as is)
imageData := fmt.Sprintf("data:%s;base64,%s", mediaMsg.Source.MediaType, mediaMsg.Source.Data)
//textContent += fmt.Sprintf("[Image: %s]", imageData)
mediaMessage := dto.MediaContent{
Type: "image_url",
ImageUrl: &dto.MessageImageUrl{Url: imageData},
}
mediaMessages = append(mediaMessages, mediaMessage)
case "tool_use":
toolCall := dto.ToolCallRequest{
ID: mediaMsg.Id,
Type: "function",
Function: dto.FunctionRequest{
Name: mediaMsg.Name,
Arguments: toJSONString(mediaMsg.Input),
},
}
toolCalls = append(toolCalls, toolCall)
case "tool_result":
// Add tool result as a separate message
oaiToolMessage := dto.Message{
Role: "tool",
Name: &mediaMsg.Name,
ToolCallId: mediaMsg.ToolUseId,
}
//oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
if mediaMsg.IsStringContent() {
oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
} else {
mediaContents := mediaMsg.ParseMediaContent()
if len(mediaContents) > 0 && mediaContents[0].Text != nil {
oaiToolMessage.SetStringContent(*mediaContents[0].Text)
}
}
openAIMessages = append(openAIMessages, oaiToolMessage)
}
}
if len(mediaMessages) > 0 {
openAIMessage.SetMediaContent(mediaMessages)
}
if len(toolCalls) > 0 {
openAIMessage.SetToolCalls(toolCalls)
}
}
if len(openAIMessage.ParseContent()) > 0 {
openAIMessages = append(openAIMessages, openAIMessage)
}
}
openAIRequest.Messages = openAIMessages
return &openAIRequest, nil
}
func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
claudeError := dto.ClaudeError{
Type: "new_api_error",
Message: openAIError.Error.Message,
}
return &dto.ClaudeErrorWithStatusCode{
Error: claudeError,
StatusCode: openAIError.StatusCode,
}
}
func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
openAIError := dto.OpenAIError{
Message: claudeError.Error.Message,
Type: "new_api_error",
}
return &dto.OpenAIErrorWithStatusCode{
Error: openAIError,
StatusCode: claudeError.StatusCode,
}
}
func generateStopBlock(index int) *dto.ClaudeResponse {
return &dto.ClaudeResponse{
Type: "content_block_stop",
Index: common.GetPointer[int](index),
}
}
func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) []*dto.ClaudeResponse {
var claudeResponses []*dto.ClaudeResponse
if info.SendResponseCount == 1 {
msg := &dto.ClaudeMediaMessage{
Id: openAIResponse.Id,
Model: openAIResponse.Model,
Type: "message",
Role: "assistant",
Usage: &dto.ClaudeUsage{
InputTokens: info.PromptTokens,
OutputTokens: 0,
},
}
msg.SetContent(make([]any, 0))
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_start",
Message: msg,
})
claudeResponses = append(claudeResponses)
//claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
// Type: "ping",
//})
if openAIResponse.IsToolCall() {
resp := &dto.ClaudeResponse{
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: openAIResponse.GetFirstToolCall().ID,
Type: "tool_use",
Name: openAIResponse.GetFirstToolCall().Function.Name,
},
}
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
} else {
resp := &dto.ClaudeResponse{
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
}
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
}
return claudeResponses
}
if len(openAIResponse.Choices) == 0 {
// no choices
// TODO: handle this case
return claudeResponses
} else {
chosenChoice := openAIResponse.Choices[0]
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
if openAIResponse.Usage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
InputTokens: openAIResponse.Usage.PromptTokens,
OutputTokens: openAIResponse.Usage.CompletionTokens,
},
Delta: &dto.ClaudeMediaMessage{
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)),
},
})
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_stop",
})
} else {
var claudeResponse dto.ClaudeResponse
claudeResponse.SetIndex(0)
claudeResponse.Type = "content_block_delta"
if len(chosenChoice.Delta.ToolCalls) > 0 {
if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
info.ClaudeConvertInfo.Index++
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: openAIResponse.GetFirstToolCall().ID,
Type: "tool_use",
Name: openAIResponse.GetFirstToolCall().Function.Name,
Input: map[string]interface{}{},
},
})
}
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
// tools delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "input_json_delta",
PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
}
} else {
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
// text delta
claudeResponse.Delta = &dto.ClaudeMediaMessage{
Type: "text_delta",
Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
}
}
claudeResponse.Index = &info.ClaudeConvertInfo.Index
claudeResponses = append(claudeResponses, &claudeResponse)
}
}
return claudeResponses
}
func ResponseOpenAI2Claude(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.ClaudeResponse {
var stopReason string
contents := make([]dto.ClaudeMediaMessage, 0)
claudeResponse := &dto.ClaudeResponse{
Id: openAIResponse.Id,
Type: "message",
Role: "assistant",
Model: openAIResponse.Model,
}
for _, choice := range openAIResponse.Choices {
stopReason = stopReasonOpenAI2Claude(choice.FinishReason)
claudeContent := dto.ClaudeMediaMessage{}
if choice.FinishReason == "tool_calls" {
claudeContent.Type = "tool_use"
claudeContent.Id = choice.Message.ToolCallId
claudeContent.Name = choice.Message.ParseToolCalls()[0].Function.Name
var mapParams map[string]interface{}
if err := json.Unmarshal([]byte(choice.Message.ParseToolCalls()[0].Function.Arguments), &mapParams); err == nil {
claudeContent.Input = mapParams
} else {
claudeContent.Input = choice.Message.ParseToolCalls()[0].Function.Arguments
}
} else {
claudeContent.Type = "text"
claudeContent.SetText(choice.Message.StringContent())
}
contents = append(contents, claudeContent)
}
claudeResponse.Content = contents
claudeResponse.StopReason = stopReason
claudeResponse.Usage = &dto.ClaudeUsage{
InputTokens: openAIResponse.PromptTokens,
OutputTokens: openAIResponse.CompletionTokens,
}
return claudeResponse
}
func stopReasonOpenAI2Claude(reason string) string {
switch reason {
case "stop":
return "end_turn"
case "stop_sequence":
return "stop_sequence"
case "max_tokens":
return "max_tokens"
case "tool_calls":
return "tool_use"
default:
return reason
}
}
func toJSONString(v interface{}) string {
b, err := json.Marshal(v)
if err != nil {
return "{}"
}
return string(b)
}

View File

@@ -50,7 +50,30 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
return openaiErr
}
func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
text := err.Error()
lowerText := strings.ToLower(text)
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
claudeError := dto.ClaudeError{
Message: text,
Type: "new_api_error",
}
return &dto.ClaudeErrorWithStatusCode{
Error: claudeError,
StatusCode: statusCode,
}
}
func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
claudeErr := ClaudeErrorWrapper(err, code, statusCode)
claudeErr.LocalError = true
return claudeErr
}
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: dto.OpenAIError{
@@ -70,6 +93,11 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
var errResponse dto.GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
if showBodyWhenFail {
errWithStatusCode.Error.Message = string(responseBody)
} else {
errWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
if errResponse.Error.Message != "" {

View File

@@ -53,3 +53,12 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
info["audio_completion_ratio"] = audioCompletionRatio
return info
}
func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} {
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
info["claude"] = true
info["cache_creation_tokens"] = cacheCreationTokens
info["cache_creation_ratio"] = cacheCreationRatio
return info
}

View File

@@ -194,6 +194,73 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatio
modelPrice := priceData.ModelPrice
cacheRatio := priceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
cacheCreationRatio := priceData.CacheCreationRatio
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
calculateQuota := 0.0
if !priceData.UsePrice {
calculateQuota = float64(promptTokens)
calculateQuota += float64(cacheTokens) * cacheRatio
calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
calculateQuota += float64(completionTokens) * completionRatio
calculateQuota = calculateQuota * groupRatio * modelRatio
} else {
calculateQuota = modelPrice * common.QuotaPerUnit * groupRatio
}
if modelRatio != 0 && calculateQuota <= 0 {
calculateQuota = 1
}
quota := int(calculateQuota)
totalTokens := promptTokens + completionTokens
var logContent string
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游出错)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
@@ -249,17 +316,18 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
logModel := relayInfo.OriginModelName
if extraContent != "" {
logContent += ", " + extraContent

View File

@@ -1,6 +1,7 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"image"
@@ -85,6 +86,9 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
}
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
if imageUrl == nil {
return 0, fmt.Errorf("image_url_is_nil")
}
baseTokens := 85
if model == "glm-4v" {
return 1047, nil
@@ -92,10 +96,10 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
if imageUrl.Detail == "low" {
return baseTokens, nil
}
// TODO: 非流模式下不计算图片token数量
if !constant.GetMediaTokenNotStream && !stream {
return 256, nil
return 3 * baseTokens, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"
@@ -125,18 +129,11 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
if err != nil {
return 0, err
}
imageUrl.MimeType = format
if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
}
//// TODO: 适配官方auto计费
//if config.Width < 512 && config.Height < 512 {
// if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
// // 如果图片尺寸小于512强制使用low
// imageUrl.Detail = "low"
// return 85, nil
// }
//}
shortSide := config.Width
otherSide := config.Height
@@ -192,6 +189,110 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
return tkm, nil
}
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
tkm := 0
// Count tokens in messages
msgTokens, err := CountTokenClaudeMessages(request.Messages, model, request.Stream)
if err != nil {
return 0, err
}
tkm += msgTokens
// Count tokens in system message
if request.System != "" {
systemTokens, err := CountTokenInput(request.System, model)
if err != nil {
return 0, err
}
tkm += systemTokens
}
if request.Tools != nil {
// check is array
if tools, ok := request.Tools.([]any); ok {
if len(tools) > 0 {
parsedTools, err1 := common.Any2Type[[]dto.Tool](request.Tools)
if err1 != nil {
return 0, fmt.Errorf("tools: Input should be a valid list: %v", err)
}
toolTokens, err2 := CountTokenClaudeTools(parsedTools, model)
if err2 != nil {
return 0, fmt.Errorf("tools: %v", err)
}
tkm += toolTokens
}
} else {
return 0, errors.New("tools: Input should be a valid list")
}
}
return tkm, nil
}
func CountTokenClaudeMessages(messages []dto.ClaudeMessage, model string, stream bool) (int, error) {
tokenEncoder := getTokenEncoder(model)
tokenNum := 0
for _, message := range messages {
// Count tokens for role
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.IsStringContent() {
tokenNum += getTokenNum(tokenEncoder, message.GetStringContent())
} else {
content, err := message.ParseContent()
if err != nil {
return 0, err
}
for _, mediaMessage := range content {
switch mediaMessage.Type {
case "text":
tokenNum += getTokenNum(tokenEncoder, mediaMessage.GetText())
case "image":
//imageTokenNum, err := getClaudeImageToken(mediaMsg.Source, model, stream)
//if err != nil {
// return 0, err
//}
tokenNum += 1000
case "tool_use":
tokenNum += getTokenNum(tokenEncoder, mediaMessage.Name)
inputJSON, _ := json.Marshal(mediaMessage.Input)
tokenNum += getTokenNum(tokenEncoder, string(inputJSON))
case "tool_result":
contentJSON, _ := json.Marshal(mediaMessage.Content)
tokenNum += getTokenNum(tokenEncoder, string(contentJSON))
}
}
}
}
// Add a constant for message formatting (this may need adjustment based on Claude's exact formatting)
tokenNum += len(messages) * 2 // Assuming 2 tokens per message for formatting
return tokenNum, nil
}
func CountTokenClaudeTools(tools []dto.Tool, model string) (int, error) {
tokenEncoder := getTokenEncoder(model)
tokenNum := 0
for _, tool := range tools {
tokenNum += getTokenNum(tokenEncoder, tool.Name)
tokenNum += getTokenNum(tokenEncoder, tool.Description)
schemaJSON, err := json.Marshal(tool.InputSchema)
if err != nil {
return 0, errors.New(fmt.Sprintf("marshal_tool_schema_fail: %s", err.Error()))
}
tokenNum += getTokenNum(tokenEncoder, string(schemaJSON))
}
// Add a constant for tool formatting (this may need adjustment based on Claude's exact formatting)
tokenNum += len(tools) * 3 // Assuming 3 tokens per tool for formatting
return tokenNum, nil
}
func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
audioToken := 0
textToken := 0
@@ -287,8 +388,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
imageUrl := m.GetImageMedia()
imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
if err != nil {
return 0, err
}

View File

@@ -0,0 +1,26 @@
package model_setting
import (
"one-api/setting/config"
)
type GlobalSettings struct {
PassThroughRequestEnabled bool `json:"pass_through_request_enabled"`
}
// 默认配置
var defaultOpenaiSettings = GlobalSettings{
PassThroughRequestEnabled: false,
}
// 全局实例
var globalSettings = defaultOpenaiSettings
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("global", &globalSettings)
}
func GetGlobalSettings() *GlobalSettings {
return &globalSettings
}

Some files were not shown because too many files have changed in this diff Show More