mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-31 17:33:23 +00:00
Compare commits
253 Commits
v0.8.5.0
...
refactor_e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9e03e6172 | ||
|
|
cb16bf552e | ||
|
|
98952198bb | ||
|
|
338e914a60 | ||
|
|
78fb457765 | ||
|
|
8759ef012f | ||
|
|
f8d67a62a2 | ||
|
|
efb98854b2 | ||
|
|
7b29f429ee | ||
|
|
265c7d93a2 | ||
|
|
ce57ad3570 | ||
|
|
0e6b608f91 | ||
|
|
f1856fe4d2 | ||
|
|
870cdd5a56 | ||
|
|
9282f1d893 | ||
|
|
9546a47f2b | ||
|
|
f0f277dc2a | ||
|
|
b695e67154 | ||
|
|
fa2cd85007 | ||
|
|
8073cbd96a | ||
|
|
5eba2f1d61 | ||
|
|
5ec421d8e6 | ||
|
|
1e25bf700d | ||
|
|
30fb349d91 | ||
|
|
d40fb68500 | ||
|
|
3049ad47e5 | ||
|
|
8945a3a2dd | ||
|
|
d191eef657 | ||
|
|
6ac7878863 | ||
|
|
c0a23ffa62 | ||
|
|
7d691f362d | ||
|
|
bf577b8937 | ||
|
|
819290c9b8 | ||
|
|
22e8b46159 | ||
|
|
76b8cc1168 | ||
|
|
fce07325b9 | ||
|
|
123862d41c | ||
|
|
7e298f8ad1 | ||
|
|
34aca14858 | ||
|
|
6b1f94348a | ||
|
|
4322037639 | ||
|
|
ae11f88595 | ||
|
|
389a4c3e4c | ||
|
|
efb691e6c2 | ||
|
|
53e3b35437 | ||
|
|
eb265a55e1 | ||
|
|
950f7d214f | ||
|
|
6bd2316d9c | ||
|
|
660180ea1b | ||
|
|
efc8457770 | ||
|
|
9b8b982d8a | ||
|
|
e6949e611a | ||
|
|
cffade7210 | ||
|
|
6b9237f868 | ||
|
|
1f4cf07b63 | ||
|
|
59a1f4c900 | ||
|
|
0a04a76c71 | ||
|
|
9e6bc518cc | ||
|
|
bfb6fbbac9 | ||
|
|
9c08d8cf20 | ||
|
|
281054ff4c | ||
|
|
3002659f47 | ||
|
|
647f8d7958 | ||
|
|
5d289d38ba | ||
|
|
05ea0dd54f | ||
|
|
1dad04ec09 | ||
|
|
2171117c53 | ||
|
|
d389befc9e | ||
|
|
3ced5ff144 | ||
|
|
38d3ab5acf | ||
|
|
ab32e15a86 | ||
|
|
25e17b95d5 | ||
|
|
d07224e658 | ||
|
|
aa15d45a3d | ||
|
|
c6c68da0b5 | ||
|
|
1a0aac81df | ||
|
|
39cb45c11c | ||
|
|
05d9aa53ef | ||
|
|
86f374df58 | ||
|
|
6935260bf0 | ||
|
|
f0d888729b | ||
|
|
6d7d4292ef | ||
|
|
fcefac9dbe | ||
|
|
ad5f731b20 | ||
|
|
76da067d40 | ||
|
|
0689670698 | ||
|
|
5a6f32c392 | ||
|
|
d6276c4692 | ||
|
|
29a44eb7ae | ||
|
|
048a625181 | ||
|
|
64782027c4 | ||
|
|
277645db50 | ||
|
|
3f53e4f53e | ||
|
|
0c5d4ca0a7 | ||
|
|
44495b153a | ||
|
|
de6e551cdb | ||
|
|
aeb393e391 | ||
|
|
db1b11deaf | ||
|
|
5a5e8ce652 | ||
|
|
6c31151430 | ||
|
|
a8ba2eba33 | ||
|
|
c974b1053c | ||
|
|
1ab75b8a92 | ||
|
|
75e3959474 | ||
|
|
bc371778b6 | ||
|
|
cd2870aebc | ||
|
|
7c72545217 | ||
|
|
2591ca3d60 | ||
|
|
c28190316f | ||
|
|
ffc22b8dac | ||
|
|
5367015a31 | ||
|
|
75c71c397e | ||
|
|
6192aebe66 | ||
|
|
a85a594597 | ||
|
|
014c9450ba | ||
|
|
63640f65e8 | ||
|
|
fd040988a3 | ||
|
|
f7c3b043b5 | ||
|
|
93e7675bc3 | ||
|
|
d7c97d4d34 | ||
|
|
dce794dbf7 | ||
|
|
093d86040f | ||
|
|
39617bc8c6 | ||
|
|
7da224ba92 | ||
|
|
df862732df | ||
|
|
fd4447f60a | ||
|
|
ea79d59aa0 | ||
|
|
41b0cf406c | ||
|
|
ef32cc8e0a | ||
|
|
ee8956b0e9 | ||
|
|
5ad9f8d931 | ||
|
|
ea379e1d0e | ||
|
|
b842baf21f | ||
|
|
58c9c7d5dd | ||
|
|
384fadf227 | ||
|
|
e4def0625b | ||
|
|
44d20de251 | ||
|
|
7ea33c2ddf | ||
|
|
b43423bffc | ||
|
|
cf4700a35c | ||
|
|
6bb552128c | ||
|
|
50b4fc06f8 | ||
|
|
f7f1be9df2 | ||
|
|
59574dc80f | ||
|
|
7577ec1ac4 | ||
|
|
d487be0029 | ||
|
|
83a3872b97 | ||
|
|
1ad2f63f85 | ||
|
|
fcaa8317e4 | ||
|
|
ccda14255a | ||
|
|
8d66828229 | ||
|
|
4ebf9e35e1 | ||
|
|
2902d6c7c2 | ||
|
|
01ef1fe4e4 | ||
|
|
c3d2d07b68 | ||
|
|
18417bacb3 | ||
|
|
8ec18dd21b | ||
|
|
edaff1c689 | ||
|
|
9c3a13cb23 | ||
|
|
0b326e7af4 | ||
|
|
1a1ff836b5 | ||
|
|
34fed74f64 | ||
|
|
f89b29928c | ||
|
|
2c6d4460c3 | ||
|
|
7afd3f97ee | ||
|
|
0708452939 | ||
|
|
a9e5d99ea3 | ||
|
|
a56d9ea98b | ||
|
|
f5e80af0b3 | ||
|
|
a1a7ddbc83 | ||
|
|
8b209d8926 | ||
|
|
9344cab59a | ||
|
|
03468e05e4 | ||
|
|
11792ba1a4 | ||
|
|
5baaa06896 | ||
|
|
d3286893c4 | ||
|
|
b087b20bac | ||
|
|
6a5a839d4d | ||
|
|
5d8a0952b4 | ||
|
|
bd08ecc1e0 | ||
|
|
e4f61c1084 | ||
|
|
a38215478f | ||
|
|
c192d07a04 | ||
|
|
098880b796 | ||
|
|
150c506ece | ||
|
|
f978d8224e | ||
|
|
ab59887933 | ||
|
|
458472f3e2 | ||
|
|
a9f98c5d39 | ||
|
|
2b7dff2d94 | ||
|
|
58752d2dcf | ||
|
|
67546f4b2a | ||
|
|
8e9dae7b5f | ||
|
|
fb4ff63bad | ||
|
|
1fed1ee567 | ||
|
|
02571c20ff | ||
|
|
8201daa4b4 | ||
|
|
5b54624cd5 | ||
|
|
db737567fb | ||
|
|
5c3898d13e | ||
|
|
2fdb2be6d0 | ||
|
|
ab78efc815 | ||
|
|
fcf97d1796 | ||
|
|
e85cc6acbe | ||
|
|
da002e6ca9 | ||
|
|
070e7b6911 | ||
|
|
d5a3eb7d04 | ||
|
|
616e6953cc | ||
|
|
b7c77777a5 | ||
|
|
8a79de333a | ||
|
|
a87d4271d3 | ||
|
|
7975cdf3bf | ||
|
|
b2badad554 | ||
|
|
133d8c9f77 | ||
|
|
9708d645d3 | ||
|
|
0bca4d3efc | ||
|
|
7572e791f6 | ||
|
|
16c63b3be9 | ||
|
|
37fbcb7950 | ||
|
|
a180d13182 | ||
|
|
a6363a502a | ||
|
|
81bc096872 | ||
|
|
edcdb378fd | ||
|
|
4447e51588 | ||
|
|
fb8aac650f | ||
|
|
ba6b0637cc | ||
|
|
3502730dfc | ||
|
|
b95c5bb8f4 | ||
|
|
f35784aa97 | ||
|
|
3746482e8c | ||
|
|
5ed4b60b8f | ||
|
|
547da2da60 | ||
|
|
f88ed4dd5c | ||
|
|
87fc681df3 | ||
|
|
a39b2f5aa7 | ||
|
|
4a8b7bfa37 | ||
|
|
296da5dbcc | ||
|
|
7403df7e9c | ||
|
|
617c8e8f4f | ||
|
|
aa793088ed | ||
|
|
0089157b83 | ||
|
|
1ec2bbd533 | ||
|
|
d67d5d8006 | ||
|
|
c4f25a77d1 | ||
|
|
52763c09f2 | ||
|
|
e77555a04f | ||
|
|
4a313a5f93 | ||
|
|
3d9587f128 | ||
|
|
66778efcc5 | ||
|
|
6be78ff283 | ||
|
|
5281f2ba64 | ||
|
|
69420f713f | ||
|
|
bc322ddac4 |
12
.env.example
12
.env.example
@@ -7,6 +7,8 @@
|
||||
# 调试相关配置
|
||||
# 启用pprof
|
||||
# ENABLE_PPROF=true
|
||||
# 启用调试模式
|
||||
# DEBUG=true
|
||||
|
||||
# 数据库相关配置
|
||||
# 数据库连接字符串
|
||||
@@ -41,6 +43,14 @@
|
||||
# 更新任务启用
|
||||
# UPDATE_TASK=true
|
||||
|
||||
# 对话超时设置
|
||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||
# RELAY_TIMEOUT=0
|
||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||
# STREAMING_TIMEOUT=120
|
||||
|
||||
# Gemini 识别图片 最大图片数量
|
||||
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||
|
||||
# 会话密钥
|
||||
# SESSION_SECRET=random_string
|
||||
@@ -58,8 +68,6 @@
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
# 设置流式一次回复的超时时间
|
||||
# STREAMING_TIMEOUT=90
|
||||
|
||||
|
||||
# 节点类型
|
||||
|
||||
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
### PR 类型
|
||||
|
||||
- [ ] Bug 修复
|
||||
- [ ] 新功能
|
||||
- [ ] 文档更新
|
||||
- [ ] 其他
|
||||
|
||||
### PR 是否包含破坏性更新?
|
||||
|
||||
- [ ] 是
|
||||
- [ ] 否
|
||||
|
||||
### PR 描述
|
||||
|
||||
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||
|
||||
### **重要提示**
|
||||
|
||||
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
|
||||
1
.github/workflows/macos-release.yml
vendored
1
.github/workflows/macos-release.yml
vendored
@@ -26,6 +26,7 @@ jobs:
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
|
||||
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Check PR Branching Strategy
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, edited]
|
||||
|
||||
jobs:
|
||||
check-branching-strategy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Enforce branching strategy
|
||||
run: |
|
||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "Branching strategy check passed."
|
||||
@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update \
|
||||
&& apk upgrade \
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||
&& update-ca-certificates
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
|
||||
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
|
||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds
|
||||
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
||||
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
||||
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
||||
|
||||
@@ -27,9 +27,6 @@
|
||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||
</a>
|
||||
<a href="https://coderabbit.ai">
|
||||
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -103,7 +100,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒
|
||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||
|
||||
71
common/api_type.go
Normal file
71
common/api_type.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType := -1
|
||||
switch channelType {
|
||||
case constant.ChannelTypeOpenAI:
|
||||
apiType = constant.APITypeOpenAI
|
||||
case constant.ChannelTypeAnthropic:
|
||||
apiType = constant.APITypeAnthropic
|
||||
case constant.ChannelTypeBaidu:
|
||||
apiType = constant.APITypeBaidu
|
||||
case constant.ChannelTypePaLM:
|
||||
apiType = constant.APITypePaLM
|
||||
case constant.ChannelTypeZhipu:
|
||||
apiType = constant.APITypeZhipu
|
||||
case constant.ChannelTypeAli:
|
||||
apiType = constant.APITypeAli
|
||||
case constant.ChannelTypeXunfei:
|
||||
apiType = constant.APITypeXunfei
|
||||
case constant.ChannelTypeAIProxyLibrary:
|
||||
apiType = constant.APITypeAIProxyLibrary
|
||||
case constant.ChannelTypeTencent:
|
||||
apiType = constant.APITypeTencent
|
||||
case constant.ChannelTypeGemini:
|
||||
apiType = constant.APITypeGemini
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
apiType = constant.APITypeZhipuV4
|
||||
case constant.ChannelTypeOllama:
|
||||
apiType = constant.APITypeOllama
|
||||
case constant.ChannelTypePerplexity:
|
||||
apiType = constant.APITypePerplexity
|
||||
case constant.ChannelTypeAws:
|
||||
apiType = constant.APITypeAws
|
||||
case constant.ChannelTypeCohere:
|
||||
apiType = constant.APITypeCohere
|
||||
case constant.ChannelTypeDify:
|
||||
apiType = constant.APITypeDify
|
||||
case constant.ChannelTypeJina:
|
||||
apiType = constant.APITypeJina
|
||||
case constant.ChannelCloudflare:
|
||||
apiType = constant.APITypeCloudflare
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
apiType = constant.APITypeSiliconFlow
|
||||
case constant.ChannelTypeVertexAi:
|
||||
apiType = constant.APITypeVertexAi
|
||||
case constant.ChannelTypeMistral:
|
||||
apiType = constant.APITypeMistral
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
apiType = constant.APITypeDeepSeek
|
||||
case constant.ChannelTypeMokaAI:
|
||||
apiType = constant.APITypeMokaAI
|
||||
case constant.ChannelTypeVolcEngine:
|
||||
apiType = constant.APITypeVolcEngine
|
||||
case constant.ChannelTypeBaiduV2:
|
||||
apiType = constant.APITypeBaiduV2
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
apiType = constant.APITypeOpenRouter
|
||||
case constant.ChannelTypeXinference:
|
||||
apiType = constant.APITypeXinference
|
||||
case constant.ChannelTypeXai:
|
||||
apiType = constant.APITypeXai
|
||||
case constant.ChannelTypeCoze:
|
||||
apiType = constant.APITypeCoze
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
}
|
||||
return apiType, true
|
||||
}
|
||||
@@ -193,107 +193,3 @@ const (
|
||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
}
|
||||
|
||||
41
common/endpoint_type.go
Normal file
41
common/endpoint_type.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||
var endpointTypes []constant.EndpointType
|
||||
switch channelType {
|
||||
case constant.ChannelTypeJina:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
|
||||
//case constant.ChannelTypeSunoAPI:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
|
||||
//case constant.ChannelTypeKling:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
|
||||
//case constant.ChannelTypeJimeng:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
|
||||
case constant.ChannelTypeAws:
|
||||
fallthrough
|
||||
case constant.ChannelTypeAnthropic:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
fallthrough
|
||||
case constant.ChannelTypeGemini:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
default:
|
||||
if IsOpenAIResponseOnlyModel(modelName) {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||
} else {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
}
|
||||
}
|
||||
if IsImageGenerationModel(modelName) {
|
||||
// add to first
|
||||
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
|
||||
}
|
||||
return endpointTypes
|
||||
}
|
||||
@@ -2,10 +2,11 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
@@ -31,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
err = Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
@@ -43,3 +44,45 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||
c.Set(string(key), value)
|
||||
}
|
||||
|
||||
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||
return c.Get(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||
return c.GetString(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||
return c.GetInt(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||
return c.GetBool(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||
return c.GetStringSlice(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||
return c.GetStringMap(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||
return c.GetTime(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) {
|
||||
if value, ok := c.Get(string(key)); ok {
|
||||
if v, ok := value.(T); ok {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
var t T
|
||||
return t, false
|
||||
}
|
||||
|
||||
57
common/http.go
Normal file
57
common/http.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
||||
if httpResponse == nil || httpResponse.Body == nil {
|
||||
return
|
||||
}
|
||||
err := httpResponse.Body.Close()
|
||||
if err != nil {
|
||||
SysError("failed to close response body: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||
if c.Writer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
body := io.NopCloser(bytes.NewBuffer(data))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
if src != nil {
|
||||
for k, v := range src.Header {
|
||||
// avoid setting Content-Length
|
||||
if k == "Content-Length" {
|
||||
continue
|
||||
}
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
}
|
||||
|
||||
// set Content-Length header manually BEFORE calling WriteHeader
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
|
||||
// Write header with status code (this sends the headers)
|
||||
if src != nil {
|
||||
c.Writer.WriteHeader(src.StatusCode)
|
||||
} else {
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
_, err := io.Copy(c.Writer, body)
|
||||
if err != nil {
|
||||
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -24,7 +25,7 @@ func printHelp() {
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
}
|
||||
|
||||
func LoadEnv() {
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
@@ -95,4 +96,25 @@ func LoadEnv() {
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
initConstantEnv()
|
||||
}
|
||||
|
||||
func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
}
|
||||
|
||||
@@ -5,14 +5,18 @@ import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func DecodeJson(data []byte, v any) error {
|
||||
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||
func Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func DecodeJsonStr(data string, v any) error {
|
||||
return DecodeJson(StringToByteSlice(data), v)
|
||||
func UnmarshalJsonStr(data string, v any) error {
|
||||
return json.Unmarshal(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func EncodeJson(v any) ([]byte, error) {
|
||||
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
func Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
42
common/model.go
Normal file
42
common/model.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package common
|
||||
|
||||
import "strings"
|
||||
|
||||
var (
|
||||
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||
OpenAIResponseOnlyModels = []string{
|
||||
"o3-pro",
|
||||
"o3-deep-research",
|
||||
"o4-mini-deep-research",
|
||||
}
|
||||
ImageGenerationModels = []string{
|
||||
"dall-e-3",
|
||||
"dall-e-2",
|
||||
"gpt-image-1",
|
||||
"prefix:imagen-",
|
||||
"flux-",
|
||||
"flux.1-",
|
||||
}
|
||||
)
|
||||
|
||||
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||
for _, m := range OpenAIResponseOnlyModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsImageGenerationModel(modelName string) bool {
|
||||
modelName = strings.ToLower(modelName)
|
||||
for _, m := range ImageGenerationModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
62
common/page_info.go
Normal file
62
common/page_info.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type PageInfo struct {
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
StartTimestamp int64 `json:"start_timestamp"` // 秒级
|
||||
EndTimestamp int64 `json:"end_timestamp"` // 秒级
|
||||
|
||||
Total int `json:"total"` // 总条数,后设置
|
||||
Items any `json:"items"` // 数据,后设置
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetStartIdx() int {
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetEndIdx() int {
|
||||
return p.Page * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPageSize() int {
|
||||
return p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPage() int {
|
||||
return p.Page
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetTotal(total int) {
|
||||
p.Total = total
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetItems(items any) {
|
||||
p.Items = items
|
||||
}
|
||||
|
||||
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
||||
pageInfo := &PageInfo{}
|
||||
err := c.BindQuery(pageInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pageInfo.Page < 1 {
|
||||
// 兼容
|
||||
page, _ := strconv.Atoi(c.Query("p"))
|
||||
if page != 0 {
|
||||
pageInfo.Page = page
|
||||
} else {
|
||||
pageInfo.Page = 1
|
||||
}
|
||||
}
|
||||
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
}
|
||||
return pageInfo, nil
|
||||
}
|
||||
@@ -16,6 +16,10 @@ import (
|
||||
var RDB *redis.Client
|
||||
var RedisEnabled = true
|
||||
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return SyncFrequency
|
||||
}
|
||||
|
||||
// InitRedisClient This function is called after init()
|
||||
func InitRedisClient() (err error) {
|
||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
@@ -31,16 +32,30 @@ func MapToJsonStr(m map[string]interface{}) string {
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
func StrToMap(str string) map[string]interface{} {
|
||||
func StrToMap(str string) (map[string]interface{}, error) {
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(str), &m)
|
||||
err := Unmarshal([]byte(str), &m)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
return m
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func IsJsonStr(str string) bool {
|
||||
func StrToJsonArray(str string) ([]interface{}, error) {
|
||||
var js []interface{}
|
||||
err := json.Unmarshal([]byte(str), &js)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func IsJsonArray(str string) bool {
|
||||
var js []interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
|
||||
func IsJsonObject(str string) bool {
|
||||
var js map[string]interface{}
|
||||
return json.Unmarshal([]byte(str), &js) == nil
|
||||
}
|
||||
@@ -68,3 +83,15 @@ func StringToByteSlice(s string) []byte {
|
||||
tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]}
|
||||
return *(*[]byte)(unsafe.Pointer(&tmp2))
|
||||
}
|
||||
|
||||
func EncodeBase64(str string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(str))
|
||||
}
|
||||
|
||||
func GetJsonString(data any) string {
|
||||
if data == nil {
|
||||
return ""
|
||||
}
|
||||
b, _ := json.Marshal(data)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||
func BuildURL(base string, endpoint string) string {
|
||||
u, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
end := endpoint
|
||||
if end == "" {
|
||||
end = "/"
|
||||
}
|
||||
ref, err := url.Parse(end)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
return u.ResolveReference(ref).String()
|
||||
}
|
||||
|
||||
26
constant/README.md
Normal file
26
constant/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# constant 包 (`/constant`)
|
||||
|
||||
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||
|
||||
## 当前文件
|
||||
|
||||
| 文件 | 说明 |
|
||||
|----------------------|---------------------------------------------------------------------|
|
||||
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||
|
||||
## 使用约定
|
||||
|
||||
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||
|
||||
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||
34
constant/api_type.go
Normal file
34
constant/api_type.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipuV4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -1,12 +1,5 @@
|
||||
package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
// 使用函数来避免初始化顺序带来的赋值问题
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return common.SyncFrequency
|
||||
}
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
UserGroupKeyFmt = "user_group:%d"
|
||||
|
||||
109
constant/channel.go
Normal file
109
constant/channel.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
ForceFormat = "force_format" // ForceFormat 强制格式化为OpenAI格式
|
||||
ChanelSettingProxy = "proxy" // Proxy 代理
|
||||
ChannelSettingThinkingToContent = "thinking_to_content" // ThinkingToContent
|
||||
)
|
||||
@@ -1,10 +1,42 @@
|
||||
package constant
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
ContextKeyUserSetting = "user_setting"
|
||||
ContextKeyUserQuota = "user_quota"
|
||||
ContextKeyUserStatus = "user_status"
|
||||
ContextKeyUserEmail = "user_email"
|
||||
ContextKeyUserGroup = "user_group"
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
|
||||
/* token related keys */
|
||||
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||
ContextKeyTokenKey ContextKey = "token_key"
|
||||
ContextKeyTokenId ContextKey = "token_id"
|
||||
ContextKeyTokenGroup ContextKey = "token_group"
|
||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelName ContextKey = "channel_name"
|
||||
ContextKeyChannelCreateTime ContextKey = "channel_create_name"
|
||||
ContextKeyChannelBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyChannelParamOverride ContextKey = "param_override"
|
||||
ContextKeyChannelOrganization ContextKey = "channel_organization"
|
||||
ContextKeyChannelAutoBan ContextKey = "auto_ban"
|
||||
ContextKeyChannelModelMapping ContextKey = "model_mapping"
|
||||
ContextKeyChannelStatusCodeMapping ContextKey = "status_code_mapping"
|
||||
ContextKeyChannelIsMultiKey ContextKey = "channel_is_multi_key"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
ContextKeyUserSetting ContextKey = "user_setting"
|
||||
ContextKeyUserQuota ContextKey = "user_quota"
|
||||
ContextKeyUserStatus ContextKey = "user_status"
|
||||
ContextKeyUserEmail ContextKey = "user_email"
|
||||
ContextKeyUserGroup ContextKey = "user_group"
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
)
|
||||
|
||||
16
constant/endpoint_type.go
Normal file
16
constant/endpoint_type.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
type EndpointType string
|
||||
|
||||
const (
|
||||
EndpointTypeOpenAI EndpointType = "openai"
|
||||
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||
EndpointTypeGemini EndpointType = "gemini"
|
||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||
//EndpointTypeKling EndpointType = "kling"
|
||||
//EndpointTypeJimeng EndpointType = "jimeng"
|
||||
)
|
||||
@@ -1,9 +1,5 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
@@ -17,39 +13,3 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
|
||||
//var GeminiModelMap = map[string]string{
|
||||
// "gemini-1.0-pro": "v1",
|
||||
//}
|
||||
|
||||
func InitEnv() {
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
|
||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
//if modelVersionMapStr == "" {
|
||||
// return
|
||||
//}
|
||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||
// parts := strings.Split(pair, ":")
|
||||
// if len(parts) == 2 {
|
||||
// GeminiModelMap[parts[0]] = parts[1]
|
||||
// } else {
|
||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ const (
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
MjActionUpload = "UPLOAD"
|
||||
MjActionVideo = "VIDEO"
|
||||
MjActionEdits = "EDITS"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
"mj_upload": MjActionUpload,
|
||||
"mj_video": MjActionVideo,
|
||||
"mj_edits": MjActionEdits,
|
||||
}
|
||||
|
||||
8
constant/multi_key_mode.go
Normal file
8
constant/multi_key_mode.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package constant
|
||||
|
||||
type MultiKeyMode string
|
||||
|
||||
const (
|
||||
MultiKeyModeRandom MultiKeyMode = "random" // 随机
|
||||
MultiKeyModePolling MultiKeyMode = "polling" // 轮询
|
||||
)
|
||||
@@ -5,11 +5,16 @@ type TaskPlatform string
|
||||
const (
|
||||
TaskPlatformSuno TaskPlatform = "suno"
|
||||
TaskPlatformMidjourney = "mj"
|
||||
TaskPlatformKling TaskPlatform = "kling"
|
||||
TaskPlatformJimeng TaskPlatform = "jimeng"
|
||||
)
|
||||
|
||||
const (
|
||||
SunoActionMusic = "MUSIC"
|
||||
SunoActionLyrics = "LYRICS"
|
||||
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
UserSettingNotifyType = "notify_type" // QuotaWarningType 额度预警类型
|
||||
UserSettingQuotaWarningThreshold = "quota_warning_threshold" // QuotaWarningThreshold 额度预警阈值
|
||||
UserSettingWebhookUrl = "webhook_url" // WebhookUrl webhook地址
|
||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
|
||||
)
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
)
|
||||
@@ -4,11 +4,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/shopspring/decimal"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -304,34 +307,70 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
type MoonshotBalanceData struct {
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
VoucherBalance float64 `json:"voucher_balance"`
|
||||
CashBalance float64 `json:"cash_balance"`
|
||||
}
|
||||
|
||||
type MoonshotBalanceResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data MoonshotBalanceData `json:"data"`
|
||||
Scode string `json:"scode"`
|
||||
Status bool `json:"status"`
|
||||
}
|
||||
|
||||
response := MoonshotBalanceResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Status || response.Code != 0 {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeOpenAI:
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
return 0, errors.New("尚未实现")
|
||||
case common.ChannelTypeCustom:
|
||||
case constant.ChannelTypeCustom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
//case common.ChannelTypeOpenAISB:
|
||||
// return updateChannelOpenAISBBalance(channel)
|
||||
case common.ChannelTypeAIProxy:
|
||||
case constant.ChannelTypeAIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case common.ChannelTypeAPI2GPT:
|
||||
case constant.ChannelTypeAPI2GPT:
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case common.ChannelTypeAIGC2D:
|
||||
case constant.ChannelTypeAIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
case common.ChannelTypeSiliconFlow:
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
return updateChannelSiliconFlowBalance(channel)
|
||||
case common.ChannelTypeDeepSeek:
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
|
||||
@@ -11,14 +11,15 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -29,17 +30,23 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, newAPIError *types.NewAPIError) {
|
||||
tik := time.Now()
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
if channel.Type == constant.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported!!!"), nil
|
||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeSunoAPI {
|
||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeJimeng {
|
||||
return errors.New("jimeng channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -50,7 +57,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
@@ -90,16 +97,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info)
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return err, types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel)
|
||||
@@ -110,45 +117,45 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return err, types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
adaptor.Init(info)
|
||||
|
||||
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return err, types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return err, types.NewError(err, types.ErrorCodeJsonMarshalFailed)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return err, types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err
|
||||
return err, types.NewError(err, types.ErrorCodeBadResponse)
|
||||
}
|
||||
}
|
||||
usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), respErr
|
||||
return respErr, respErr
|
||||
}
|
||||
if usageA == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
return errors.New("usage is nil"), types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
usage := usageA.(*dto.Usage)
|
||||
result := w.Result()
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
return err, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
|
||||
}
|
||||
info.PromptTokens = usage.PromptTokens
|
||||
|
||||
@@ -167,8 +174,19 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
||||
model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
||||
ChannelId: channel.Id,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
ModelName: info.OriginModelName,
|
||||
TokenName: "模型测试",
|
||||
Quota: quota,
|
||||
Content: "模型测试",
|
||||
UseTimeSeconds: int(consumedTime),
|
||||
IsStream: false,
|
||||
Group: info.UsingGroup,
|
||||
Other: other,
|
||||
})
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
@@ -196,7 +214,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
testRequest.MaxTokens = 50
|
||||
}
|
||||
} else if strings.Contains(model, "gemini") {
|
||||
testRequest.MaxTokens = 300
|
||||
testRequest.MaxTokens = 3000
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
@@ -229,15 +247,15 @@ func TestChannel(c *gin.Context) {
|
||||
}
|
||||
testModel := c.Query("model")
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, testModel)
|
||||
_, newAPIError := testChannel(channel, testModel)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
if err != nil {
|
||||
if newAPIError != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
"message": newAPIError.Error(),
|
||||
"time": consumedTime,
|
||||
})
|
||||
return
|
||||
@@ -281,17 +299,15 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiWithStatusErr := testChannel(channel, "")
|
||||
err, newAPIError := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
shouldBanChannel := false
|
||||
|
||||
// request error disables the channel
|
||||
if openaiWithStatusErr != nil {
|
||||
oaiErr := openaiWithStatusErr.Error
|
||||
err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
|
||||
if err != nil {
|
||||
shouldBanChannel = service.ShouldDisableChannel(channel.Type, newAPIError)
|
||||
}
|
||||
|
||||
if milliseconds > disableThreshold {
|
||||
@@ -305,7 +321,7 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
|
||||
// enable channel
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) {
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, newAPIError, channel.Status) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
|
||||
@@ -337,6 +353,10 @@ func TestAllChannels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
if frequency <= 0 {
|
||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("testing all channels")
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -40,6 +41,17 @@ type OpenAIModelsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
func parseStatusFilter(statusParam string) int {
|
||||
switch strings.ToLower(statusParam) {
|
||||
case "enabled", "1":
|
||||
return common.ChannelStatusEnabled
|
||||
case "disabled", "0":
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
@@ -52,44 +64,100 @@ func GetAllChannels(c *gin.Context) {
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
statusParam := c.Query("status")
|
||||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
// type filter
|
||||
typeStr := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeStr != "" {
|
||||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
// tag 分页:先分页 tag,再取各 tag 下 channels
|
||||
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag != nil && *tag != "" {
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err == nil {
|
||||
channelData = append(channelData, tagChannel...)
|
||||
}
|
||||
if tag == nil || *tag == "" {
|
||||
continue
|
||||
}
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
filtered := make([]*model.Channel, 0)
|
||||
for _, ch := range tagChannels {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if typeFilter >= 0 && ch.Type != typeFilter {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = append(channelData, filtered...)
|
||||
}
|
||||
// 计算 tag 总数用于分页
|
||||
total, _ = model.CountAllTags()
|
||||
} else {
|
||||
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
|
||||
baseQuery := model.DB.Model(&model.Channel{})
|
||||
if typeFilter >= 0 {
|
||||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||||
}
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
|
||||
baseQuery.Count(&total)
|
||||
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
total, _ = model.CountAllChannels()
|
||||
}
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
var results []struct {
|
||||
Type int64
|
||||
Count int64
|
||||
}
|
||||
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
typeCounts[r.Type] = r.Count
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -114,22 +182,15 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
@@ -153,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
var ids []string
|
||||
for _, model := range result.Data {
|
||||
id := model.ID
|
||||
if channel.Type == common.ChannelTypeGemini {
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
id = strings.TrimPrefix(id, "models/")
|
||||
}
|
||||
ids = append(ids, id)
|
||||
@@ -167,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
func FixChannelsAbilities(c *gin.Context) {
|
||||
count, err := model.FixAbility()
|
||||
success, fails, err := model.FixAbility()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -178,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
"data": gin.H{
|
||||
"success": success,
|
||||
"fails": fails,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -186,6 +250,8 @@ func SearchChannels(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
modelKeyword := c.Query("model")
|
||||
statusParam := c.Query("status")
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
channelData := make([]*model.Channel, 0)
|
||||
@@ -217,10 +283,74 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
|
||||
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
// calculate type counts for search results
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, channel := range channelData {
|
||||
typeCounts[int64(channel.Type)]++
|
||||
}
|
||||
|
||||
typeParam := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeParam != "" {
|
||||
if tp, err := strconv.Atoi(typeParam); err == nil {
|
||||
typeFilter = tp
|
||||
}
|
||||
}
|
||||
|
||||
if typeFilter >= 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if ch.Type == typeFilter {
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
total := len(channelData)
|
||||
startIdx := (page - 1) * pageSize
|
||||
if startIdx > total {
|
||||
startIdx = total
|
||||
}
|
||||
endIdx := startIdx + pageSize
|
||||
if endIdx > total {
|
||||
endIdx = total
|
||||
}
|
||||
|
||||
pagedData := channelData[startIdx:endIdx]
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": pagedData,
|
||||
"total": total,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -250,9 +380,47 @@ func GetChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type AddChannelRequest struct {
|
||||
Mode string `json:"mode"`
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
Channel *model.Channel `json:"channel"`
|
||||
}
|
||||
|
||||
func getVertexArrayKeys(keys string) ([]string, error) {
|
||||
if keys == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var keyArray []interface{}
|
||||
err := common.Unmarshal([]byte(keys), &keyArray)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
|
||||
}
|
||||
cleanKeys := make([]string, 0, len(keyArray))
|
||||
for _, key := range keyArray {
|
||||
var keyStr string
|
||||
switch v := key.(type) {
|
||||
case string:
|
||||
keyStr = strings.TrimSpace(v)
|
||||
default:
|
||||
bytes, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
|
||||
}
|
||||
keyStr = string(bytes)
|
||||
}
|
||||
if keyStr != "" {
|
||||
cleanKeys = append(cleanKeys, keyStr)
|
||||
}
|
||||
}
|
||||
if len(cleanKeys) == 0 {
|
||||
return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
|
||||
}
|
||||
return cleanKeys, nil
|
||||
}
|
||||
|
||||
func AddChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
addChannelRequest := AddChannelRequest{}
|
||||
err := c.ShouldBindJSON(&addChannelRequest)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -260,49 +428,120 @@ func AddChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.CreatedTime = common.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
|
||||
err = addChannelRequest.Channel.ValidateSettings()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "channel setting 格式错误:" + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if addChannelRequest.Channel == nil || addChannelRequest.Channel.Key == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "channel cannot be empty",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the length of the model name
|
||||
for _, m := range addChannelRequest.Channel.GetModels() {
|
||||
if len(m) > 255 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("模型名称过长: %s", m),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区不能为空",
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
regionMap, err := common.StrToMap(addChannelRequest.Channel.Other)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||
})
|
||||
return
|
||||
}
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
keys = []string{channel.Key}
|
||||
}
|
||||
|
||||
addChannelRequest.Channel.CreatedTime = common.GetTimestamp()
|
||||
keys := make([]string, 0)
|
||||
switch addChannelRequest.Mode {
|
||||
case "multi_to_single":
|
||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
addChannelRequest.Channel.Key = strings.Join(array, "\n")
|
||||
} else {
|
||||
cleanKeys := make([]string, 0)
|
||||
for _, key := range strings.Split(addChannelRequest.Channel.Key, "\n") {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
cleanKeys = append(cleanKeys, key)
|
||||
}
|
||||
addChannelRequest.Channel.Key = strings.Join(cleanKeys, "\n")
|
||||
}
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
case "batch":
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
// multi json
|
||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
keys = strings.Split(addChannelRequest.Channel.Key, "\n")
|
||||
}
|
||||
case "single":
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "不支持的添加模式",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channels := make([]model.Channel, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
localChannel := channel
|
||||
localChannel := addChannelRequest.Channel
|
||||
localChannel.Key = key
|
||||
// Validate the length of the model name
|
||||
models := strings.Split(localChannel.Models, ",")
|
||||
for _, model := range models {
|
||||
if len(model) > 255 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("模型名称过长: %s", model),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
channels = append(channels, localChannel)
|
||||
channels = append(channels, *localChannel)
|
||||
}
|
||||
err = model.BatchInsertChannels(channels)
|
||||
if err != nil {
|
||||
@@ -487,7 +726,15 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
err = channel.ValidateSettings()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "channel setting 格式错误:" + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -495,16 +742,20 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
} else {
|
||||
if common.IsJsonStr(channel.Other) {
|
||||
// must have default
|
||||
regionMap := common.StrToMap(channel.Other)
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
regionMap, err := common.StrToMap(channel.Other)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}",
|
||||
})
|
||||
return
|
||||
}
|
||||
if regionMap["default"] == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "部署地区必须包含default字段",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -516,6 +767,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.Key = ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -541,7 +793,7 @@ func FetchModels(c *gin.Context) {
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = common.ChannelBaseURLs[req.Type]
|
||||
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -25,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
|
||||
@@ -76,6 +76,7 @@ func GetStatus(c *gin.Context) {
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
|
||||
@@ -2,6 +2,8 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -13,10 +15,7 @@ import (
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -25,30 +24,10 @@ var openAIModels []dto.OpenAIModels
|
||||
var openAIModelsMap map[string]dto.OpenAIModels
|
||||
var channelId2Models map[int][]string
|
||||
|
||||
func getPermission() []dto.OpenAIModelPermission {
|
||||
var permission []dto.OpenAIModelPermission
|
||||
permission = append(permission, dto.OpenAIModelPermission{
|
||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
||||
Object: "model_permission",
|
||||
Created: 1626777600,
|
||||
AllowCreateEngine: true,
|
||||
AllowSampling: true,
|
||||
AllowLogprobs: true,
|
||||
AllowSearchIndices: false,
|
||||
AllowView: true,
|
||||
AllowFineTuning: false,
|
||||
Organization: "*",
|
||||
Group: nil,
|
||||
IsBlocking: false,
|
||||
})
|
||||
return permission
|
||||
}
|
||||
|
||||
func init() {
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
permission := getPermission()
|
||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||
if i == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 0; i < constant.APITypeDummy; i++ {
|
||||
if i == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
@@ -56,69 +35,51 @@ func init() {
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, modelName := range ai360.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range moonshot.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range lingyiwanwu.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range minimax.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
})
|
||||
}
|
||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
})
|
||||
}
|
||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||
@@ -126,9 +87,9 @@ func init() {
|
||||
openAIModelsMap[aiModel.Id] = aiModel
|
||||
}
|
||||
channelId2Models = make(map[int][]string)
|
||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||
apiType, success := common.ChannelType2APIType(i)
|
||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||
@@ -136,15 +97,17 @@ func init() {
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||
return m.Id
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
permission := getPermission()
|
||||
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -152,23 +115,22 @@ func ListModels(c *gin.Context) {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: allowModel,
|
||||
Parent: nil,
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId, true)
|
||||
userGroup, err := model.GetUserGroup(userId, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -177,14 +139,14 @@ func ListModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := c.GetString("token_group")
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupModels(autoGroup)
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
@@ -192,20 +154,19 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupModels(group)
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
}
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
for _, modelName := range models {
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: s,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: s,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
@@ -103,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -3,45 +3,44 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
if newAPIError != nil {
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
@@ -52,19 +51,34 @@ func Playground(c *gin.Context) {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||
|
||||
userId := c.GetInt("id")
|
||||
//c.Set("token_name", "playground-"+group)
|
||||
tempToken := &model.Token{
|
||||
UserId: userId,
|
||||
Name: fmt.Sprintf("playground-%s", group),
|
||||
Group: group,
|
||||
}
|
||||
_ = middleware.SetupContextForToken(c, tempToken)
|
||||
_, err = getChannel(c, group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package controller
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -13,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
@@ -22,7 +22,7 @@ func GetPricing(c *gin.Context) {
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
for g := range groupRatio {
|
||||
ratio, ok := setting.GetGroupGroupRatio(group, g)
|
||||
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||
if ok {
|
||||
groupRatio[g] = ratio
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func GetPricing(c *gin.Context) {
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
||||
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -56,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
|
||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
474
controller/ratio_sync.go
Normal file
474
controller/ratio_sync.go
Normal file
@@ -0,0 +1,474 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
)
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
|
||||
var upstreams []dto.UpstreamDTO
|
||||
|
||||
if len(req.Upstreams) > 0 {
|
||||
for _, u := range req.Upstreams {
|
||||
if strings.HasPrefix(u.BaseURL, "http") {
|
||||
if u.Endpoint == "" {
|
||||
u.Endpoint = defaultEndpoint
|
||||
}
|
||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
}
|
||||
} else if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
ID: ch.Id,
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isType1 {
|
||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
}
|
||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||
return
|
||||
}
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
|
||||
if len(modelRatioMap) > 0 {
|
||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
ratioAny[k] = v
|
||||
}
|
||||
converted["model_ratio"] = ratioAny
|
||||
}
|
||||
|
||||
if len(completionRatioMap) > 0 {
|
||||
compAny := make(map[string]any, len(completionRatioMap))
|
||||
for k, v := range completionRatioMap {
|
||||
compAny[k] = v
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
priceAny[k] = v
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}) map[string]map[string]dto.DifferenceItem {
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
confidenceMap := make(map[string]map[string]bool)
|
||||
|
||||
// 预处理阶段:检查pricing接口的可信度
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
|
||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||
for modelName := range allModels {
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
confidenceValues := make(map[string]bool)
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
|
||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||
}
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
Confidence: confidenceValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
delete(item.Confidence, chName)
|
||||
}
|
||||
}
|
||||
|
||||
allSame := true
|
||||
for _, v := range item.Upstreams {
|
||||
if v != "same" {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(item.Upstreams) == 0 || allSame {
|
||||
delete(ratioMap, ratioType)
|
||||
} else {
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
|
||||
if len(ratioMap) == 0 {
|
||||
delete(differences, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
return differences
|
||||
}
|
||||
|
||||
func GetSyncableChannels(c *gin.Context) {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
@@ -8,23 +8,24 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
|
||||
var err *types.NewAPIError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err = relay.ImageHelper(c)
|
||||
@@ -55,43 +56,43 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
userGroup := c.GetString("group")
|
||||
channelId := c.GetInt("channel_id")
|
||||
other := make(map[string]interface{})
|
||||
other["error_type"] = err.Error.Type
|
||||
other["error_code"] = err.Error.Code
|
||||
other["error_type"] = err.ErrorType
|
||||
other["error_code"] = err.GetErrorCode()
|
||||
other["status_code"] = err.StatusCode
|
||||
other["channel_id"] = channelId
|
||||
other["channel_name"] = c.GetString("channel_name")
|
||||
other["channel_type"] = c.GetInt("channel_type")
|
||||
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error.Message, tokenId, 0, false, userGroup, other)
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = relayRequest(c, relayMode, channel)
|
||||
newAPIError = relayRequest(c, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -101,14 +102,14 @@ func Relay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", openaiErr.Error.Message))
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
if newAPIError != nil {
|
||||
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
common.LogError(c, fmt.Sprintf("origin 429 error: %s", newAPIError.Error()))
|
||||
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"error": newAPIError.ToOpenAIError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,35 +128,34 @@ func WssRelay(c *gin.Context) {
|
||||
defer ws.Close()
|
||||
|
||||
if err != nil {
|
||||
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
originalModel := c.GetString("original_model")
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
break
|
||||
}
|
||||
|
||||
openaiErr = wssRequest(c, ws, relayMode, channel)
|
||||
newAPIError = wssRequest(c, ws, relayMode, channel)
|
||||
|
||||
if openaiErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -165,12 +165,12 @@ func WssRelay(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if openaiErr != nil {
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
if newAPIError != nil {
|
||||
if newAPIError.StatusCode == http.StatusTooManyRequests {
|
||||
newAPIError.SetMessage("当前分组上游负载已饱和,请稍后再试")
|
||||
}
|
||||
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
|
||||
helper.WssError(c, ws, openaiErr.Error)
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
helper.WssError(c, ws, newAPIError.ToOpenAIError())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,27 +179,25 @@ func RelayClaude(c *gin.Context) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
var claudeErr *dto.ClaudeErrorWithStatusCode
|
||||
var newAPIError *types.NewAPIError
|
||||
|
||||
for i := 0; i <= common.RetryTimes; i++ {
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
claudeErr = service.ClaudeErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
newAPIError = types.NewError(err, types.ErrorCodeGetChannelFailed)
|
||||
break
|
||||
}
|
||||
|
||||
claudeErr = claudeRequest(c, channel)
|
||||
newAPIError = claudeRequest(c, channel)
|
||||
|
||||
if claudeErr == nil {
|
||||
if newAPIError == nil {
|
||||
return // 成功处理请求,直接返回
|
||||
}
|
||||
|
||||
openaiErr := service.ClaudeErrorToOpenAIError(claudeErr)
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), newAPIError)
|
||||
|
||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||
|
||||
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||
if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -209,30 +207,30 @@ func RelayClaude(c *gin.Context) {
|
||||
common.LogInfo(c, retryLogStr)
|
||||
}
|
||||
|
||||
if claudeErr != nil {
|
||||
claudeErr.Error.Message = common.MessageWithRequestId(claudeErr.Error.Message, requestId)
|
||||
c.JSON(claudeErr.StatusCode, gin.H{
|
||||
if newAPIError != nil {
|
||||
newAPIError.SetMessage(common.MessageWithRequestId(newAPIError.Error(), requestId))
|
||||
c.JSON(newAPIError.StatusCode, gin.H{
|
||||
"type": "error",
|
||||
"error": claudeErr.Error,
|
||||
"error": newAPIError.ToClaudeError(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relayHandler(c, relayMode)
|
||||
}
|
||||
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||
func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return relay.WssHelper(c, ws)
|
||||
}
|
||||
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *dto.ClaudeErrorWithStatusCode {
|
||||
func claudeRequest(c *gin.Context, channel *model.Channel) *types.NewAPIError {
|
||||
addUsedChannel(c, channel.Id)
|
||||
requestBody, _ := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
@@ -259,19 +257,25 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
if group == "auto" {
|
||||
return nil, errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error()))
|
||||
}
|
||||
return nil, errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error()))
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||
func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) bool {
|
||||
if openaiErr == nil {
|
||||
return false
|
||||
}
|
||||
if openaiErr.LocalError {
|
||||
if types.IsChannelError(openaiErr) {
|
||||
return true
|
||||
}
|
||||
if types.IsLocalError(openaiErr) {
|
||||
return false
|
||||
}
|
||||
if retryTimes <= 0 {
|
||||
@@ -295,7 +299,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType == common.ChannelTypeAnthropic {
|
||||
if channelType == constant.ChannelTypeAnthropic {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -310,12 +314,12 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *types.NewAPIError) {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error()))
|
||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||
service.DisableChannel(channelId, channelName, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -388,9 +392,10 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||
channel, err := getChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||
break
|
||||
}
|
||||
channelId = channel.Id
|
||||
@@ -398,7 +403,7 @@ func RelayTask(c *gin.Context) {
|
||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||
c.Set("use_channel", useChannel)
|
||||
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||
|
||||
requestBody, err := common.GetRequestBody(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
@@ -420,7 +425,7 @@ func RelayTask(c *gin.Context) {
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
|
||||
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
}
|
||||
|
||||
138
controller/task_video.go
Normal file
138
controller/task_video.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if taskResult.Code != 0 {
|
||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
//}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
return fmt.Errorf("task %s status is empty", taskId)
|
||||
}
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -258,3 +258,32 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type TokenBatch struct {
|
||||
Ids []int `json:"ids"`
|
||||
}
|
||||
|
||||
func DeleteTokenBatch(c *gin.Context) {
|
||||
tokenBatch := TokenBatch{}
|
||||
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
payType := "wxpay"
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = "alipay"
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = "wxpay"
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
Type: req.PaymentMethod,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
@@ -246,15 +247,15 @@ func Register(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetAllUsers(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
pageInfo, err := common.GetPageQuery(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "parse page query failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
|
||||
users, total, err := model.GetAllUsers(pageInfo)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -262,15 +263,13 @@ func GetAllUsers(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": users,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
"data": pageInfo,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -489,7 +488,7 @@ func GetUserModels(c *gin.Context) {
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
var models []string
|
||||
for group := range groups {
|
||||
for _, g := range model.GetGroupModels(group) {
|
||||
for _, g := range model.GetGroupEnabledModels(group) {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
@@ -963,7 +962,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证预警类型
|
||||
if req.QuotaWarningType != constant.NotifyTypeEmail && req.QuotaWarningType != constant.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的预警类型",
|
||||
@@ -981,7 +980,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果是webhook类型,验证webhook地址
|
||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
if req.WebhookUrl == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -1000,7 +999,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 如果是邮件类型,验证邮箱地址
|
||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
// 验证邮箱格式
|
||||
if !strings.Contains(req.NotificationEmail, "@") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -1022,24 +1021,24 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 构建设置
|
||||
settings := map[string]interface{}{
|
||||
constant.UserSettingNotifyType: req.QuotaWarningType,
|
||||
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
|
||||
constant.UserSettingRecordIpLog: req.RecordIpLog,
|
||||
settings := dto.UserSetting{
|
||||
NotifyType: req.QuotaWarningType,
|
||||
QuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel,
|
||||
RecordIpLog: req.RecordIpLog,
|
||||
}
|
||||
|
||||
// 如果是webhook类型,添加webhook相关设置
|
||||
if req.QuotaWarningType == constant.NotifyTypeWebhook {
|
||||
settings[constant.UserSettingWebhookUrl] = req.WebhookUrl
|
||||
if req.QuotaWarningType == dto.NotifyTypeWebhook {
|
||||
settings.WebhookUrl = req.WebhookUrl
|
||||
if req.WebhookSecret != "" {
|
||||
settings[constant.UserSettingWebhookSecret] = req.WebhookSecret
|
||||
settings.WebhookSecret = req.WebhookSecret
|
||||
}
|
||||
}
|
||||
|
||||
// 如果提供了通知邮箱,添加到设置中
|
||||
if req.QuotaWarningType == constant.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
settings[constant.UserSettingNotificationEmail] = req.NotificationEmail
|
||||
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
|
||||
settings.NotificationEmail = req.NotificationEmail
|
||||
}
|
||||
|
||||
// 更新用户设置
|
||||
|
||||
@@ -16,7 +16,7 @@ services:
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释
|
||||
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
|
||||
7
dto/channel_settings.go
Normal file
7
dto/channel_settings.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package dto
|
||||
|
||||
type ChannelSettings struct {
|
||||
ForceFormat bool `json:"force_format,omitempty"`
|
||||
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
|
||||
Proxy string `json:"proxy"`
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package dto
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
@@ -228,7 +229,7 @@ type ClaudeResponse struct {
|
||||
Completion string `json:"completion,omitempty"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Error *ClaudeError `json:"error,omitempty"`
|
||||
Error *types.ClaudeError `json:"error,omitempty"`
|
||||
Usage *ClaudeUsage `json:"usage,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
|
||||
|
||||
@@ -15,6 +15,7 @@ type ImageRequest struct {
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
|
||||
12
dto/error.go
12
dto/error.go
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
@@ -14,11 +16,11 @@ type OpenAIErrorWithStatusCode struct {
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Error types.OpenAIError `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
|
||||
@@ -57,6 +57,8 @@ type MidjourneyDto struct {
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
VideoUrl string `json:"videoUrl"`
|
||||
VideoUrls []ImgUrls `json:"videoUrls"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
|
||||
Properties *Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type ImgUrls struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
@@ -53,9 +53,11 @@ type GeneralOpenAIRequest struct {
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
@@ -63,8 +65,8 @@ type GeneralOpenAIRequest struct {
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.EncodeJson(r)
|
||||
_ = common.DecodeJson(data, &result)
|
||||
data, _ := common.Marshal(r)
|
||||
_ = common.Unmarshal(data, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -644,4 +646,6 @@ type ResponsesToolsCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Function json.RawMessage `json:"function,omitempty"`
|
||||
Container json.RawMessage `json:"container,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package dto
|
||||
|
||||
import "encoding/json"
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
type SimpleResponse struct {
|
||||
Usage `json:"usage"`
|
||||
@@ -26,9 +29,9 @@ type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
@@ -178,6 +181,8 @@ type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||
// OpenRouter Params
|
||||
Cost float64 `json:"cost,omitempty"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
@@ -199,7 +204,7 @@ type OpenAIResponsesResponse struct {
|
||||
Object string `json:"object"`
|
||||
CreatedAt int `json:"created_at"`
|
||||
Status string `json:"status"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
|
||||
Instructions string `json:"instructions"`
|
||||
MaxOutputTokens int `json:"max_output_tokens"`
|
||||
|
||||
@@ -1,26 +1,11 @@
|
||||
package dto
|
||||
|
||||
type OpenAIModelPermission struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
||||
AllowSampling bool `json:"allow_sampling"`
|
||||
AllowLogprobs bool `json:"allow_logprobs"`
|
||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
||||
AllowView bool `json:"allow_view"`
|
||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
||||
Organization string `json:"organization"`
|
||||
Group *string `json:"group"`
|
||||
IsBlocking bool `json:"is_blocking"`
|
||||
}
|
||||
import "one-api/constant"
|
||||
|
||||
type OpenAIModels struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []OpenAIModelPermission `json:"permission"`
|
||||
Root string `json:"root"`
|
||||
Parent *string `json:"parent"`
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
38
dto/ratio_sync.go
Normal file
38
dto/ratio_sync.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package dto
|
||||
|
||||
type UpstreamDTO struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type UpstreamRequest struct {
|
||||
ChannelIDs []int64 `json:"channel_ids"`
|
||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
// TestResult 上游测试连通性结果
|
||||
type TestResult struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// DifferenceItem 差异项
|
||||
// Current 为本地值,可能为 nil
|
||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||
|
||||
type DifferenceItem struct {
|
||||
Current interface{} `json:"current"`
|
||||
Upstreams map[string]interface{} `json:"upstreams"`
|
||||
Confidence map[string]bool `json:"confidence"`
|
||||
}
|
||||
|
||||
type SyncableChannel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package dto
|
||||
|
||||
import "one-api/types"
|
||||
|
||||
const (
|
||||
RealtimeEventTypeError = "error"
|
||||
RealtimeEventTypeSessionUpdate = "session.update"
|
||||
@@ -23,12 +25,12 @@ type RealtimeEvent struct {
|
||||
EventId string `json:"event_id"`
|
||||
Type string `json:"type"`
|
||||
//PreviousItemId string `json:"previous_item_id"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Session *RealtimeSession `json:"session,omitempty"`
|
||||
Item *RealtimeItem `json:"item,omitempty"`
|
||||
Error *types.OpenAIError `json:"error,omitempty"`
|
||||
Response *RealtimeResponse `json:"response,omitempty"`
|
||||
Delta string `json:"delta,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type RealtimeResponse struct {
|
||||
|
||||
@@ -4,7 +4,7 @@ type RerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
TopN int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||
|
||||
16
dto/user_settings.go
Normal file
16
dto/user_settings.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package dto
|
||||
|
||||
type UserSetting struct {
|
||||
NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型
|
||||
QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值
|
||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||
}
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
)
|
||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package dto
|
||||
|
||||
type VideoRequest struct {
|
||||
Model string `json:"model,omitempty" example:"kling-v1"` // Model/style ID
|
||||
Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"` // Text prompt
|
||||
Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
|
||||
Duration float64 `json:"duration" example:"5.0"` // Video duration (seconds)
|
||||
Width int `json:"width" example:"512"` // Video width
|
||||
Height int `json:"height" example:"512"` // Video height
|
||||
Fps int `json:"fps,omitempty" example:"30"` // Video frame rate
|
||||
Seed int `json:"seed,omitempty" example:"20231234"` // Random seed
|
||||
N int `json:"n,omitempty" example:"1"` // Number of videos to generate
|
||||
ResponseFormat string `json:"response_format,omitempty" example:"url"` // Response format
|
||||
User string `json:"user,omitempty" example:"user-1234"` // User identifier
|
||||
Metadata map[string]any `json:"metadata,omitempty"` // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
|
||||
}
|
||||
|
||||
// VideoResponse 视频生成提交任务后的响应
|
||||
type VideoResponse struct {
|
||||
TaskId string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// VideoTaskResponse 查询视频生成任务状态的响应
|
||||
type VideoTaskResponse struct {
|
||||
TaskId string `json:"task_id" example:"abcd1234efgh"` // 任务ID
|
||||
Status string `json:"status" example:"succeeded"` // 任务状态
|
||||
Url string `json:"url,omitempty"` // 视频资源URL(成功时)
|
||||
Format string `json:"format,omitempty" example:"mp4"` // 视频格式
|
||||
Metadata *VideoTaskMetadata `json:"metadata,omitempty"` // 结果元数据
|
||||
Error *VideoTaskError `json:"error,omitempty"` // 错误信息(失败时)
|
||||
}
|
||||
|
||||
// VideoTaskMetadata 视频任务元数据
|
||||
type VideoTaskMetadata struct {
|
||||
Duration float64 `json:"duration" example:"5.0"` // 实际生成的视频时长
|
||||
Fps int `json:"fps" example:"30"` // 实际帧率
|
||||
Width int `json:"width" example:"512"` // 实际宽度
|
||||
Height int `json:"height" example:"512"` // 实际高度
|
||||
Seed int `json:"seed" example:"20231234"` // 使用的随机种子
|
||||
}
|
||||
|
||||
// VideoTaskError 视频任务错误信息
|
||||
type VideoTaskError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
1041
i18n/zh-cn.json
Normal file
1041
i18n/zh-cn.json
Normal file
File diff suppressed because it is too large
Load Diff
92
main.go
92
main.go
@@ -12,7 +12,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -32,14 +32,13 @@ var buildFS embed.FS
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load(".env")
|
||||
|
||||
err := InitResources()
|
||||
if err != nil {
|
||||
common.SysLog("Support for .env file is disabled: " + err.Error())
|
||||
common.FatalLog("failed to initialize resources: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
common.LoadEnv()
|
||||
|
||||
common.SetupLogger()
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -47,19 +46,7 @@ func main() {
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("running in debug mode")
|
||||
}
|
||||
// Initialize SQL Database
|
||||
err = model.InitDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
|
||||
model.CheckSetup()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
defer func() {
|
||||
err := model.CloseDB()
|
||||
if err != nil {
|
||||
@@ -67,21 +54,6 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize Redis
|
||||
err = common.InitRedisClient()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
operation_setting.InitRatioSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
model.InitOptionMap()
|
||||
|
||||
service.InitTokenEncoders()
|
||||
|
||||
if common.RedisEnabled {
|
||||
// for compatibility with old versions
|
||||
common.MemoryCacheEnabled = true
|
||||
@@ -96,9 +68,9 @@ func main() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
|
||||
// Retry once
|
||||
_, fixErr := model.FixAbility()
|
||||
_, _, fixErr := model.FixAbility()
|
||||
if fixErr != nil {
|
||||
common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -186,3 +158,53 @@ func main() {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func InitResources() error {
|
||||
// Initialize resources here if needed
|
||||
// This is a placeholder function for future resource initialization
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
|
||||
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||
}
|
||||
|
||||
common.SetupLogger()
|
||||
|
||||
// 加载环境变量
|
||||
common.InitEnv()
|
||||
|
||||
// Initialize model settings
|
||||
ratio_setting.InitRatioSettings()
|
||||
|
||||
service.InitHttpClient()
|
||||
|
||||
service.InitTokenEncoders()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
model.CheckSetup()
|
||||
|
||||
// Initialize options, should after model.InitDB()
|
||||
model.InitOptionMap()
|
||||
|
||||
// 初始化模型
|
||||
model.GetPricing()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Redis
|
||||
err = common.InitRedisClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -184,7 +185,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
// gemini api 从query中获取key
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
skKey := c.Query("key")
|
||||
if skKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||
@@ -233,30 +234,41 @@ func TokenAuth() func(c *gin.Context) {
|
||||
|
||||
userCache.WriteContext(c)
|
||||
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
}
|
||||
err = SetupContextForToken(c, token, parts...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func SetupContextForToken(c *gin.Context, token *model.Token, parts ...string) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token is nil")
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set("token_unlimited_quota", token.UnlimitedQuota)
|
||||
if !token.UnlimitedQuota {
|
||||
c.Set("token_quota", token.RemainQuota)
|
||||
}
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
c.Set("allow_ips", token.GetIpLimitsMap())
|
||||
c.Set("token_group", token.Group)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
} else {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return fmt.Errorf("普通用户不支持指定渠道")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -20,11 +21,12 @@ import (
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
allowIpsMap := c.GetStringMap("allow_ips")
|
||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
||||
if len(allowIpsMap) != 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
@@ -33,14 +35,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
@@ -48,7 +50,7 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
@@ -56,7 +58,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
c.Set("group", userGroup)
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -75,9 +77,9 @@ func Distribute() func(c *gin.Context) {
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -120,7 +122,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
@@ -169,7 +171,26 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
var platform string
|
||||
var relayMode int
|
||||
if strings.HasPrefix(modelRequest.Model, "jimeng") {
|
||||
platform = string(constant.TaskPlatformJimeng)
|
||||
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeJimengFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
} else {
|
||||
platform = string(constant.TaskPlatformKling)
|
||||
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
c.Set("platform", platform)
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
relayMode := relayconstant.RelayModeGemini
|
||||
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
|
||||
@@ -217,6 +238,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
|
||||
// playground chat completions
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return nil, false, errors.New("无效的请求, " + err.Error())
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
@@ -225,37 +254,42 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
if channel == nil {
|
||||
return
|
||||
}
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
c.Set("channel_type", channel.Type)
|
||||
c.Set("channel_create_time", channel.CreatedTime)
|
||||
c.Set("channel_setting", channel.GetSetting())
|
||||
c.Set("param_override", channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
|
||||
common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelParamOverride, channel.GetParamOverride())
|
||||
if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
|
||||
}
|
||||
common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
|
||||
|
||||
}
|
||||
c.Set("auto_ban", channel.GetAutoBan())
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())
|
||||
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeVertexAi:
|
||||
case constant.ChannelTypeVertexAi:
|
||||
c.Set("region", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
case constant.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
case common.ChannelCloudflare:
|
||||
case constant.ChannelCloudflare:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeMokaAI:
|
||||
case constant.ChannelTypeMokaAI:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeCoze:
|
||||
case constant.ChannelTypeCoze:
|
||||
c.Set("bot_id", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
47
middleware/kling_adapter.go
Normal file
47
middleware/kling_adapter.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func KlingRequestConvert() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
var originalReq map[string]interface{}
|
||||
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := originalReq["model"].(string)
|
||||
prompt, _ := originalReq["prompt"].(string)
|
||||
|
||||
unifiedReq := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"metadata": originalReq,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(unifiedReq)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite request body and path
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||
c.Request.URL.Path = "/v1/video/generations"
|
||||
if image := originalReq["image"]; image == "" {
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// We have to reset the request body for the next handlers
|
||||
c.Set(common.KeyRequestBody, jsonData)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
|
||||
// 获取分组
|
||||
group := c.GetString("token_group")
|
||||
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if group == "" {
|
||||
group = c.GetString(constant.ContextKeyUserGroup)
|
||||
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
}
|
||||
|
||||
//获取分组的限流配置
|
||||
|
||||
119
model/ability.go
119
model/ability.go
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
@@ -21,7 +22,22 @@ type Ability struct {
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
}
|
||||
|
||||
func GetGroupModels(group string) []string {
|
||||
type AbilityWithChannel struct {
|
||||
Ability
|
||||
ChannelType int `json:"channel_type"`
|
||||
}
|
||||
|
||||
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
||||
var abilities []AbilityWithChannel
|
||||
err := DB.Table("abilities").
|
||||
Select("abilities.*, channels.type as channel_type").
|
||||
Joins("left join channels on abilities.channel_id = channels.id").
|
||||
Where("abilities.enabled = ?", true).
|
||||
Scan(&abilities).Error
|
||||
return abilities, err
|
||||
}
|
||||
|
||||
func GetGroupEnabledModels(group string) []string {
|
||||
var models []string
|
||||
// Find distinct models
|
||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
@@ -46,7 +62,7 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
@@ -72,14 +88,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
|
||||
return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
|
||||
}
|
||||
|
||||
func FixAbility() (int, error) {
|
||||
var channelIds []int
|
||||
count := 0
|
||||
// Find all channel ids from channel table
|
||||
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
|
||||
var fixLock = sync.Mutex{}
|
||||
|
||||
func FixAbility() (int, int, error) {
|
||||
lock := fixLock.TryLock()
|
||||
if !lock {
|
||||
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
||||
}
|
||||
defer fixLock.Unlock()
|
||||
var channels []*Channel
|
||||
// Find all channels
|
||||
err := DB.Model(&Channel{}).Find(&channels).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
|
||||
if len(channelIds) > 0 {
|
||||
// Process deletion in chunks to avoid "too many placeholders" error
|
||||
for _, chunk := range lo.Chunk(channelIds, 100) {
|
||||
err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no channels exist, delete all abilities
|
||||
err = DB.Delete(&Ability{}).Error
|
||||
if len(channels) == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
successCount := 0
|
||||
failCount := 0
|
||||
for _, chunk := range lo.Chunk(channels, 50) {
|
||||
ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
|
||||
// Delete all abilities of this channel
|
||||
err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
|
||||
return 0, err
|
||||
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
||||
failCount += len(chunk)
|
||||
continue
|
||||
}
|
||||
common.SysLog("Delete all abilities successfully")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
|
||||
count += len(channelIds)
|
||||
|
||||
// Use channelIds to find channel not in abilities table
|
||||
var abilityChannelIds []int
|
||||
err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
}
|
||||
|
||||
var channels []Channel
|
||||
if len(abilityChannelIds) == 0 {
|
||||
err = DB.Find(&channels).Error
|
||||
} else {
|
||||
// Process query in chunks to avoid "too many placeholders" error
|
||||
err = nil
|
||||
for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
|
||||
var channelsChunk []Channel
|
||||
err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
|
||||
// Then add new abilities
|
||||
for _, channel := range chunk {
|
||||
err = channel.AddAbilities()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
|
||||
return count, err
|
||||
common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
||||
failCount++
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
channels = append(channels, channelsChunk...)
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
err := channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
|
||||
count++
|
||||
}
|
||||
}
|
||||
InitChannelCache()
|
||||
return count, nil
|
||||
return successCount, failCount, nil
|
||||
}
|
||||
|
||||
152
model/channel.go
152
model/channel.go
@@ -1,8 +1,13 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -35,8 +40,100 @@ type Channel struct {
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
}
|
||||
|
||||
type ChannelInfo struct {
|
||||
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
|
||||
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status
|
||||
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
|
||||
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
func (c ChannelInfo) Value() (driver.Value, error) {
|
||||
return common.Marshal(&c)
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner interface
|
||||
func (c *ChannelInfo) Scan(value interface{}) error {
|
||||
bytesValue, _ := value.([]byte)
|
||||
return common.Unmarshal(bytesValue, c)
|
||||
}
|
||||
|
||||
func (channel *Channel) getKeys() []string {
|
||||
if channel.Key == "" {
|
||||
return []string{}
|
||||
}
|
||||
// use \n to split keys
|
||||
keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n")
|
||||
return keys
|
||||
}
|
||||
|
||||
func (channel *Channel) GetNextEnabledKey() (string, error) {
|
||||
// If not in multi-key mode, return the original key string directly.
|
||||
if !channel.ChannelInfo.IsMultiKey {
|
||||
return channel.Key, nil
|
||||
}
|
||||
|
||||
// Obtain all keys (split by \n)
|
||||
keys := channel.getKeys()
|
||||
if len(keys) == 0 {
|
||||
// No keys available, return error, should disable the channel
|
||||
return "", fmt.Errorf("no valid keys in channel")
|
||||
}
|
||||
|
||||
statusList := channel.ChannelInfo.MultiKeyStatusList
|
||||
// helper to get key status, default to enabled when missing
|
||||
getStatus := func(idx int) int {
|
||||
if statusList == nil {
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
if status, ok := statusList[idx]; ok {
|
||||
return status
|
||||
}
|
||||
return common.ChannelStatusEnabled
|
||||
}
|
||||
|
||||
// Collect indexes of enabled keys
|
||||
enabledIdx := make([]int, 0, len(keys))
|
||||
for i := range keys {
|
||||
if getStatus(i) == common.ChannelStatusEnabled {
|
||||
enabledIdx = append(enabledIdx, i)
|
||||
}
|
||||
}
|
||||
// If no specific status list or none enabled, fall back to first key
|
||||
if len(enabledIdx) == 0 {
|
||||
return keys[0], nil
|
||||
}
|
||||
|
||||
switch channel.ChannelInfo.MultiKeyMode {
|
||||
case constant.MultiKeyModeRandom:
|
||||
// Randomly pick one enabled key
|
||||
return keys[enabledIdx[rand.Intn(len(enabledIdx))]], nil
|
||||
case constant.MultiKeyModePolling:
|
||||
// Start from the saved polling index and look for the next enabled key
|
||||
start := channel.ChannelInfo.MultiKeyPollingIndex
|
||||
if start < 0 || start >= len(keys) {
|
||||
start = 0
|
||||
}
|
||||
for i := 0; i < len(keys); i++ {
|
||||
idx := (start + i) % len(keys)
|
||||
if getStatus(idx) == common.ChannelStatusEnabled {
|
||||
// update polling index for next call (point to the next position)
|
||||
channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys)
|
||||
return keys[idx], nil
|
||||
}
|
||||
}
|
||||
// Fallback – should not happen, but return first enabled key
|
||||
return keys[enabledIdx[0]], nil
|
||||
default:
|
||||
// Unknown mode, default to first enabled key (or original key string)
|
||||
return keys[enabledIdx[0]], nil
|
||||
}
|
||||
}
|
||||
|
||||
func (channel *Channel) GetModels() []string {
|
||||
@@ -514,8 +611,19 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
return tags, nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetSetting() map[string]interface{} {
|
||||
setting := make(map[string]interface{})
|
||||
func (channel *Channel) ValidateSettings() error {
|
||||
channelParams := &dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) GetSetting() dto.ChannelSettings {
|
||||
setting := dto.ChannelSettings{}
|
||||
if channel.Setting != nil && *channel.Setting != "" {
|
||||
err := json.Unmarshal([]byte(*channel.Setting), &setting)
|
||||
if err != nil {
|
||||
@@ -525,7 +633,7 @@ func (channel *Channel) GetSetting() map[string]interface{} {
|
||||
return setting
|
||||
}
|
||||
|
||||
func (channel *Channel) SetSetting(setting map[string]interface{}) {
|
||||
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
@@ -597,3 +705,39 @@ func CountAllTags() (int64, error) {
|
||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Get channels of specified type with pagination
|
||||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// Count channels of specific type
|
||||
func CountChannelsByType(channelType int) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Return map[type]count for all channels
|
||||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||||
type result struct {
|
||||
Type int64 `gorm:"column:type"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.Type] = r.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
63
model/log.go
63
model/log.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -50,7 +49,7 @@ func formatUserLogs(logs []*Log) {
|
||||
for i := range logs {
|
||||
logs[i].ChannelName = ""
|
||||
var otherMap map[string]interface{}
|
||||
otherMap = common.StrToMap(logs[i].Other)
|
||||
otherMap, _ = common.StrToMap(logs[i].Other)
|
||||
if otherMap != nil {
|
||||
// delete admin
|
||||
delete(otherMap, "admin_info")
|
||||
@@ -100,10 +99,8 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
if settingMap.RecordIpLog {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
@@ -136,22 +133,34 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
type RecordConsumeLogParams struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ModelName string `json:"model_name"`
|
||||
TokenName string `json:"token_name"`
|
||||
Quota int `json:"quota"`
|
||||
Content string `json:"content"`
|
||||
TokenId int `json:"token_id"`
|
||||
UserQuota int `json:"user_quota"`
|
||||
UseTimeSeconds int `json:"use_time_seconds"`
|
||||
IsStream bool `json:"is_stream"`
|
||||
Group string `json:"group"`
|
||||
Other map[string]interface{} `json:"other"`
|
||||
}
|
||||
|
||||
func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
|
||||
common.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
otherStr := common.MapToJsonStr(params.Other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
if settingMap.RecordIpLog {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
@@ -159,17 +168,17 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
Username: username,
|
||||
CreatedAt: common.GetTimestamp(),
|
||||
Type: LogTypeConsume,
|
||||
Content: content,
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TokenName: tokenName,
|
||||
ModelName: modelName,
|
||||
Quota: quota,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Content: params.Content,
|
||||
PromptTokens: params.PromptTokens,
|
||||
CompletionTokens: params.CompletionTokens,
|
||||
TokenName: params.TokenName,
|
||||
ModelName: params.ModelName,
|
||||
Quota: params.Quota,
|
||||
ChannelId: params.ChannelId,
|
||||
TokenId: params.TokenId,
|
||||
UseTime: params.UseTimeSeconds,
|
||||
IsStream: params.IsStream,
|
||||
Group: params.Group,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
@@ -184,7 +193,7 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
}
|
||||
if common.DataExportEnabled {
|
||||
gopool.Go(func() {
|
||||
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens)
|
||||
LogQuotaData(userId, username, params.ModelName, params.Quota, common.GetTimestamp(), params.PromptTokens+params.CompletionTokens)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,18 @@ func initCol() {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
} else {
|
||||
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
|
||||
if common.UsingPostgreSQL {
|
||||
logGroupCol = `"group"`
|
||||
logKeyCol = `"key"`
|
||||
} else {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
}
|
||||
// log sql type and database type
|
||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
//common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
}
|
||||
|
||||
var DB *gorm.DB
|
||||
@@ -216,12 +225,6 @@ func InitLogDB() (err error) {
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
//if common.UsingMySQL {
|
||||
// _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded
|
||||
// _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded
|
||||
//}
|
||||
common.SysLog("database migration started")
|
||||
err = migrateLOGDB()
|
||||
return err
|
||||
|
||||
@@ -14,6 +14,8 @@ type Midjourney struct {
|
||||
StartTime int64 `json:"start_time" gorm:"index"`
|
||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
VideoUrls string `json:"video_urls"`
|
||||
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -78,6 +79,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -96,13 +98,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
@@ -125,6 +127,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
|
||||
|
||||
// 自动添加所有注册的模型配置
|
||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||
@@ -265,6 +268,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
case "ExposeRatioEnabled":
|
||||
ratio_setting.SetExposeRatioEnabled(boolValue)
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -358,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DataExportDefaultTime":
|
||||
common.DataExportDefaultTime = value
|
||||
case "ModelRatio":
|
||||
err = operation_setting.UpdateModelRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupRatioByJSONString(value)
|
||||
case "GroupGroupRatio":
|
||||
err = setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = operation_setting.UpdateCompletionRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = operation_setting.UpdateModelPriceByJSONString(value)
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = operation_setting.UpdateCacheRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
//case "ChatLink":
|
||||
@@ -387,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
118
model/pricing.go
118
model/pricing.go
@@ -1,20 +1,24 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/constant"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Pricing struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups,omitempty"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -23,56 +27,98 @@ var (
|
||||
updatePricingLock sync.Mutex
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
var (
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
modelSupportEndpointsLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
updatePricing()
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
// Double check after acquiring the lock
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
modelSupportEndpointsLock.Lock()
|
||||
defer modelSupportEndpointsLock.Unlock()
|
||||
updatePricing()
|
||||
}
|
||||
}
|
||||
//if group != "" {
|
||||
// userPricingMap := make([]Pricing, 0)
|
||||
// models := GetGroupModels(group)
|
||||
// for _, pricing := range pricingMap {
|
||||
// if !common.StringsContains(models, pricing.ModelName) {
|
||||
// pricing.Available = false
|
||||
// }
|
||||
// userPricingMap = append(userPricingMap, pricing)
|
||||
// }
|
||||
// return userPricingMap
|
||||
//}
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||
if model == "" {
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
modelSupportEndpointsLock.RLock()
|
||||
defer modelSupportEndpointsLock.RUnlock()
|
||||
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||||
return endpoints
|
||||
}
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
|
||||
func updatePricing() {
|
||||
//modelRatios := common.GetModelRatios()
|
||||
enableAbilities := GetAllEnableAbilities()
|
||||
modelGroupsMap := make(map[string][]string)
|
||||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||
return
|
||||
}
|
||||
modelGroupsMap := make(map[string]*types.Set[string])
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
groups := modelGroupsMap[ability.Model]
|
||||
if groups == nil {
|
||||
groups = make([]string, 0)
|
||||
groups, ok := modelGroupsMap[ability.Model]
|
||||
if !ok {
|
||||
groups = types.NewSet[string]()
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
}
|
||||
if !common.StringsContains(groups, ability.Group) {
|
||||
groups = append(groups, ability.Group)
|
||||
groups.Add(ability.Group)
|
||||
}
|
||||
|
||||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||
modelSupportEndpointsStr := make(map[string][]string)
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||||
if !ok {
|
||||
endpoints = make([]string, 0)
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||
for _, channelType := range channelTypes {
|
||||
if !common.StringsContains(endpoints, string(channelType)) {
|
||||
endpoints = append(endpoints, string(channelType))
|
||||
}
|
||||
}
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
for model, endpoints := range modelSupportEndpointsStr {
|
||||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||
for _, endpointStr := range endpoints {
|
||||
endpointType := constant.EndpointType(endpointStr)
|
||||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||||
}
|
||||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||
}
|
||||
|
||||
pricingMap = make([]Pricing, 0)
|
||||
for model, groups := range modelGroupsMap {
|
||||
pricing := Pricing{
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
ModelName: model,
|
||||
EnableGroup: groups.Items(),
|
||||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||
}
|
||||
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
pricing.ModelPrice = modelPrice
|
||||
pricing.QuotaType = 1
|
||||
} else {
|
||||
modelRatio, _ := operation_setting.GetModelRatio(model)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(model)
|
||||
pricing.ModelRatio = modelRatio
|
||||
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
|
||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
|
||||
@@ -327,3 +327,37 @@ func CountUserTokens(userId int) (int64, error) {
|
||||
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
|
||||
func BatchDeleteTokens(ids []int, userId int) (int, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, errors.New("ids 不能为空!")
|
||||
}
|
||||
|
||||
tx := DB.Begin()
|
||||
|
||||
var tokens []Token
|
||||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if common.RedisEnabled {
|
||||
gopool.Go(func() {
|
||||
for _, t := range tokens {
|
||||
_ = cacheDeleteToken(t.Key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return len(tokens), nil
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -68,14 +69,18 @@ func (user *User) SetAccessToken(token string) {
|
||||
user.AccessToken = &token
|
||||
}
|
||||
|
||||
func (user *User) GetSetting() map[string]interface{} {
|
||||
if user.Setting == "" {
|
||||
return nil
|
||||
func (user *User) GetSetting() dto.UserSetting {
|
||||
setting := dto.UserSetting{}
|
||||
if user.Setting != "" {
|
||||
err := json.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
}
|
||||
return common.StrToMap(user.Setting)
|
||||
return setting
|
||||
}
|
||||
|
||||
func (user *User) SetSetting(setting map[string]interface{}) {
|
||||
func (user *User) SetSetting(setting dto.UserSetting) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
@@ -114,7 +119,7 @@ func GetMaxUserId() int {
|
||||
return user.Id
|
||||
}
|
||||
|
||||
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
|
||||
func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
|
||||
// Start transaction
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
@@ -134,7 +139,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error)
|
||||
}
|
||||
|
||||
// Get paginated users within same transaction
|
||||
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
||||
err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, 0, err
|
||||
@@ -626,7 +631,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||
}
|
||||
|
||||
// GetUserSetting gets setting from Redis first, falls back to DB if needed
|
||||
func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err error) {
|
||||
func GetUserSetting(id int, fromDB bool) (settingMap dto.UserSetting, err error) {
|
||||
var setting string
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
@@ -648,10 +653,12 @@ func GetUserSetting(id int, fromDB bool) (settingMap map[string]interface{}, err
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("setting").Find(&setting).Error
|
||||
if err != nil {
|
||||
return map[string]interface{}{}, err
|
||||
return settingMap, err
|
||||
}
|
||||
|
||||
return common.StrToMap(setting), nil
|
||||
userBase := &UserBase{
|
||||
Setting: setting,
|
||||
}
|
||||
return userBase.GetSetting(), nil
|
||||
}
|
||||
|
||||
func IncreaseUserQuota(id int, quota int, db bool) (err error) {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -24,28 +24,23 @@ type UserBase struct {
|
||||
}
|
||||
|
||||
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
c.Set(constant.ContextKeyUserGroup, user.Group)
|
||||
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
||||
c.Set(constant.ContextKeyUserStatus, user.Status)
|
||||
c.Set(constant.ContextKeyUserEmail, user.Email)
|
||||
c.Set("username", user.Username)
|
||||
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
|
||||
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
|
||||
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
|
||||
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
|
||||
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
|
||||
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
|
||||
}
|
||||
|
||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||
if user.Setting == "" {
|
||||
return nil
|
||||
func (user *UserBase) GetSetting() dto.UserSetting {
|
||||
setting := dto.UserSetting{}
|
||||
if user.Setting != "" {
|
||||
err := common.Unmarshal([]byte(user.Setting), &setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to unmarshal setting: " + err.Error())
|
||||
}
|
||||
}
|
||||
return common.StrToMap(user.Setting)
|
||||
}
|
||||
|
||||
func (user *UserBase) SetSetting(setting map[string]interface{}) {
|
||||
settingBytes, err := json.Marshal(setting)
|
||||
if err != nil {
|
||||
common.SysError("failed to marshal setting: " + err.Error())
|
||||
return
|
||||
}
|
||||
user.Setting = string(settingBytes)
|
||||
return setting
|
||||
}
|
||||
|
||||
// getUserCacheKey returns the key for user cache
|
||||
@@ -70,7 +65,7 @@ func updateUserCache(user User) error {
|
||||
return common.RedisHSetObj(
|
||||
getUserCacheKey(user.Id),
|
||||
user.ToBaseUser(),
|
||||
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
|
||||
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -174,11 +169,10 @@ func getUserNameCache(userId int) (string, error) {
|
||||
return cache.Username, nil
|
||||
}
|
||||
|
||||
func getUserSettingCache(userId int) (map[string]interface{}, error) {
|
||||
setting := make(map[string]interface{})
|
||||
func getUserSettingCache(userId int) (dto.UserSetting, error) {
|
||||
cache, err := GetUserCache(userId)
|
||||
if err != nil {
|
||||
return setting, err
|
||||
return dto.UserSetting{}, err
|
||||
}
|
||||
return cache.GetSetting(), nil
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
|
||||
}
|
||||
|
||||
func batchUpdate() {
|
||||
// check if there's any data to update
|
||||
hasData := false
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
if len(batchUpdateStores[i]) > 0 {
|
||||
hasData = true
|
||||
batchUpdateLocks[i].Unlock()
|
||||
break
|
||||
}
|
||||
batchUpdateLocks[i].Unlock()
|
||||
}
|
||||
|
||||
if !hasData {
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
|
||||
@@ -3,7 +3,6 @@ package relay
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -12,7 +11,10 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
@@ -54,29 +56,26 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
return audioRequest, nil
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
|
||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest)
|
||||
return types.NewError(err, types.ErrorCodeInvalidRequest)
|
||||
}
|
||||
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||
}
|
||||
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
preConsumedTokens = promptTokens
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeModelPriceError)
|
||||
}
|
||||
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
@@ -89,27 +88,25 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
|
||||
}
|
||||
|
||||
audioRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
|
||||
}
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeDoRequestFailed)
|
||||
}
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -117,18 +114,18 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if openaiErr != nil {
|
||||
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
|
||||
if newAPIError != nil {
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ type Adaptor interface {
|
||||
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
|
||||
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
|
||||
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode)
|
||||
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
|
||||
GetModelList() []string
|
||||
GetChannelName() string
|
||||
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
|
||||
@@ -44,4 +45,6 @@ type TaskAdaptor interface {
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -30,7 +31,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
@@ -82,7 +83,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return embeddingRequestOpenAI2Ali(request), nil
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
@@ -99,7 +100,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
@@ -109,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -4,15 +4,17 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
|
||||
@@ -124,52 +126,46 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relayc
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
if aliTaskResponse.Message != "" {
|
||||
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
|
||||
return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, nil
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,10 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -30,32 +31,26 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
usage := dto.Usage{
|
||||
@@ -70,14 +65,10 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -8,9 +8,10 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -38,42 +39,26 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var aliResponse AliEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var fullTextResponse dto.OpenAIEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
model := c.GetString("model")
|
||||
if model == "" {
|
||||
model = "text-embedding-v4"
|
||||
}
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
@@ -135,7 +120,7 @@ func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStre
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
@@ -186,42 +171,33 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: "ali_error",
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -206,8 +206,8 @@ func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
|
||||
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||
client, err = service.NewProxyHttpClient(proxyURL.(string))
|
||||
if info.ChannelSetting.Proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -84,7 +85,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -3,19 +3,22 @@ package aws
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/claude"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||
)
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
@@ -65,24 +68,21 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) (string, error) {
|
||||
func awsModelID(requestModel string) string {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID, nil
|
||||
return awsModelID
|
||||
}
|
||||
|
||||
return requestModel, nil
|
||||
return requestModel
|
||||
}
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -98,42 +98,42 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
awsCli, err := newAwsClient(c, info)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||
}
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -149,25 +149,25 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
claudeReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return wrapErr(errors.New("request not found")), nil
|
||||
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
claudeReq := claudeReq_.(*dto.ClaudeRequest)
|
||||
|
||||
awsClaudeReq := copyRequest(claudeReq)
|
||||
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
|
||||
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
stream := awsResp.GetStream()
|
||||
defer stream.Close()
|
||||
|
||||
claudeInfo := &claude.ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
@@ -176,18 +176,18 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
|
||||
for event := range stream.Events() {
|
||||
switch v := event.(type) {
|
||||
case *types.ResponseStreamMemberChunk:
|
||||
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
|
||||
info.SetFirstResponseTime()
|
||||
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
|
||||
if respErr != nil {
|
||||
return respErr, nil
|
||||
}
|
||||
case *types.UnknownUnionMember:
|
||||
case *bedrockruntimeTypes.UnknownUnionMember:
|
||||
fmt.Println("unknown tag:", v.Tag)
|
||||
return wrapErr(errors.New("unknown response type")), nil
|
||||
return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||
default:
|
||||
fmt.Println("union is nil or unknown type")
|
||||
return wrapErr(errors.New("nil or unknown response type")), nil
|
||||
return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -140,15 +141,15 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = baiduStreamHandler(c, resp)
|
||||
err, usage = baiduStreamHandler(c, info, resp)
|
||||
} else {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
err, usage = baiduEmbeddingHandler(c, info, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
err, usage = baiduHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
|
||||
@@ -110,98 +112,49 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAI
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var usage dto.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
helper.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
usage := &dto.Usage{}
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
err := common.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
if baiduResponse.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = baiduResponse.Usage.TotalTokens
|
||||
usage.PromptTokens = baiduResponse.Usage.PromptTokens
|
||||
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
|
||||
}
|
||||
response := streamResponseBaidu2OpenAI(&baiduResponse)
|
||||
err = helper.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.SysError("error sending stream response: " + err.Error())
|
||||
}
|
||||
return true
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var baiduResponse BaiduChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -209,35 +162,24 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var baiduResponse BaiduEmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
Code: baiduResponse.ErrorCode,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -280,7 +222,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
res, err := service.GetImpatientHttpClient().Do(req)
|
||||
res, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -42,7 +43,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 0 || keyParts[0] == "" {
|
||||
return errors.New("invalid API key: authorization token is required")
|
||||
}
|
||||
if len(keyParts) > 1 {
|
||||
if keyParts[1] != "" {
|
||||
req.Set("appid", keyParts[1])
|
||||
}
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+keyParts[0])
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -83,11 +93,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||
usage, err = openai.OaiStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
usage, err = openai.OpenaiHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -94,7 +95,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
|
||||
} else {
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -122,6 +124,21 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
}
|
||||
|
||||
if textRequest.Reasoning != nil {
|
||||
var reasoning openrouter.RequestReasoning
|
||||
if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
budgetTokens := reasoning.MaxTokens
|
||||
if budgetTokens > 0 {
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: &budgetTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textRequest.Stop != nil {
|
||||
// stop maybe string/array string, convert to array string
|
||||
switch textRequest.Stop.(type) {
|
||||
@@ -501,22 +518,15 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
return true
|
||||
}
|
||||
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.DecodeJsonStr(data, &claudeResponse)
|
||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Code: "stream_response_error",
|
||||
Type: claudeResponse.Error.Type,
|
||||
Message: claudeResponse.Error.Message,
|
||||
},
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
@@ -549,7 +559,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
@@ -558,7 +568,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
if common.DebugEnabled {
|
||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -577,15 +587,15 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
}
|
||||
}
|
||||
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
var err *dto.OpenAIErrorWithStatusCode
|
||||
var err *types.NewAPIError
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
|
||||
if err != nil {
|
||||
@@ -601,27 +611,17 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.DecodeJson(data, &claudeResponse)
|
||||
err := common.Unmarshal(data, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Code: claudeResponse.Error.Type,
|
||||
},
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
}
|
||||
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
||||
}
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
@@ -639,20 +639,21 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
openaiResponse.Usage = *claudeInfo.Usage
|
||||
responseData, err = json.Marshal(openaiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
responseData = data
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, err = c.Writer.Write(responseData)
|
||||
|
||||
common.IOCopyBytesGracefully(c, nil, responseData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
ResponseId: helper.GetResponseID(c),
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
ResponseText: strings.Builder{},
|
||||
@@ -660,9 +661,8 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -94,20 +95,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fallthrough
|
||||
case constant.RelayModeChatCompletions:
|
||||
if info.IsStream {
|
||||
err, usage = cfStreamHandler(c, resp, info)
|
||||
err, usage = cfStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = cfHandler(c, resp, info)
|
||||
err, usage = cfHandler(c, info, resp)
|
||||
}
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
err, usage = cfSTTHandler(c, resp, info)
|
||||
err, usage = cfSTTHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package cloudflare
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -11,8 +10,11 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
|
||||
@@ -25,7 +27,7 @@ func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfReque
|
||||
}
|
||||
}
|
||||
|
||||
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
@@ -71,7 +73,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -81,39 +83,33 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
response.Model = info.UpstreamModelName
|
||||
var responseText string
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -121,19 +117,16 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
var cfResp CfAudioResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &cfResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
audioResp := &dto.AudioResponse{
|
||||
@@ -142,7 +135,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
jsonResponse, err := json.Marshal(audioResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -150,7 +143,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return nil, usage
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -71,14 +72,14 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.RelayMode == constant.RelayModeRerank {
|
||||
err, usage = cohereRerankHandler(c, resp, info)
|
||||
} else {
|
||||
if info.IsStream {
|
||||
err, usage = cohereStreamHandler(c, resp, info)
|
||||
err, usage = cohereStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||
err, usage = cohereHandler(c, info, resp)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -3,7 +3,6 @@ package cohere
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -11,8 +10,11 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||
@@ -76,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
@@ -162,25 +164,22 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
createdTime := common.GetTimestamp()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||
@@ -191,7 +190,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
openaiResp.Id = cohereResp.ResponseId
|
||||
openaiResp.Created = createdTime
|
||||
openaiResp.Object = "chat.completion"
|
||||
openaiResp.Model = modelName
|
||||
openaiResp.Model = info.UpstreamModelName
|
||||
openaiResp.Usage = usage
|
||||
|
||||
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
||||
@@ -204,7 +203,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
|
||||
jsonResponse, err := json.Marshal(openaiResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -212,19 +211,16 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereRerankResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
usage := dto.Usage{}
|
||||
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
||||
@@ -243,7 +239,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
|
||||
jsonResponse, err := json.Marshal(rerankResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/common"
|
||||
"one-api/types"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -95,11 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
|
||||
}
|
||||
|
||||
// DoResponse implements channel.Adaptor.
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
|
||||
if info.IsStream {
|
||||
err, usage = cozeChatStreamHandler(c, resp, info)
|
||||
err, usage = cozeChatStreamHandler(c, info, resp)
|
||||
} else {
|
||||
err, usage = cozeChatHandler(c, resp, info)
|
||||
err, usage = cozeChatHandler(c, info, resp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -43,25 +44,22 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
|
||||
return cozeRequest
|
||||
}
|
||||
|
||||
func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
// convert coze response to openai response
|
||||
var response dto.TextResponse
|
||||
var cozeResponse CozeChatDetailResponse
|
||||
response.Model = info.UpstreamModelName
|
||||
err = json.Unmarshal(responseBody, &cozeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
if cozeResponse.Code != 0 {
|
||||
return service.OpenAIErrorWrapper(errors.New(cozeResponse.Msg), fmt.Sprintf("%d", cozeResponse.Code), http.StatusInternalServerError), nil
|
||||
return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
// 从上下文获取 usage
|
||||
var usage dto.Usage
|
||||
@@ -88,7 +86,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -97,7 +95,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
helper.SetEventStreamHeaders(c)
|
||||
@@ -106,7 +104,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
var currentEvent string
|
||||
var currentData string
|
||||
var usage dto.Usage
|
||||
var usage = &dto.Usage{}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -114,7 +112,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
if line == "" {
|
||||
if currentEvent != "" && currentData != "" {
|
||||
// handle last event
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
currentEvent = ""
|
||||
currentData = ""
|
||||
}
|
||||
@@ -134,21 +132,19 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
// Last event
|
||||
if currentEvent != "" && currentData != "" {
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return nil, &usage
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||
@@ -283,8 +279,8 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
|
||||
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
|
||||
var client *http.Client
|
||||
var err error // 声明 err 变量
|
||||
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||
client, err = service.NewProxyHttpClient(proxyURL.(string))
|
||||
if info.ChannelSetting.Proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user