mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-15 04:07:26 +00:00
Compare commits
45 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5eba2f1d61 | ||
|
|
5ec421d8e6 | ||
|
|
1e25bf700d | ||
|
|
30fb349d91 | ||
|
|
d40fb68500 | ||
|
|
3049ad47e5 | ||
|
|
8945a3a2dd | ||
|
|
d191eef657 | ||
|
|
6ac7878863 | ||
|
|
c0a23ffa62 | ||
|
|
7d691f362d | ||
|
|
bf577b8937 | ||
|
|
819290c9b8 | ||
|
|
22e8b46159 | ||
|
|
76b8cc1168 | ||
|
|
fce07325b9 | ||
|
|
123862d41c | ||
|
|
7e298f8ad1 | ||
|
|
34aca14858 | ||
|
|
6b1f94348a | ||
|
|
4322037639 | ||
|
|
ae11f88595 | ||
|
|
389a4c3e4c | ||
|
|
efb691e6c2 | ||
|
|
53e3b35437 | ||
|
|
eb265a55e1 | ||
|
|
950f7d214f | ||
|
|
6bd2316d9c | ||
|
|
660180ea1b | ||
|
|
efc8457770 | ||
|
|
9b8b982d8a | ||
|
|
e6949e611a | ||
|
|
cffade7210 | ||
|
|
6b9237f868 | ||
|
|
1f4cf07b63 | ||
|
|
59a1f4c900 | ||
|
|
0a04a76c71 | ||
|
|
9e6bc518cc | ||
|
|
bfb6fbbac9 | ||
|
|
9c08d8cf20 | ||
|
|
281054ff4c | ||
|
|
3002659f47 | ||
|
|
d389befc9e | ||
|
|
c6c68da0b5 | ||
|
|
76da067d40 |
12
.env.example
12
.env.example
@@ -7,6 +7,8 @@
|
||||
# 调试相关配置
|
||||
# 启用pprof
|
||||
# ENABLE_PPROF=true
|
||||
# 启用调试模式
|
||||
# DEBUG=true
|
||||
|
||||
# 数据库相关配置
|
||||
# 数据库连接字符串
|
||||
@@ -41,6 +43,14 @@
|
||||
# 更新任务启用
|
||||
# UPDATE_TASK=true
|
||||
|
||||
# 对话超时设置
|
||||
# 所有请求超时时间,单位秒,默认为0,表示不限制
|
||||
# RELAY_TIMEOUT=0
|
||||
# 流模式无响应超时时间,单位秒,如果出现空补全可以尝试改为更大值
|
||||
# STREAMING_TIMEOUT=120
|
||||
|
||||
# Gemini 识别图片 最大图片数量
|
||||
# GEMINI_VISION_MAX_IMAGE_NUM=16
|
||||
|
||||
# 会话密钥
|
||||
# SESSION_SECRET=random_string
|
||||
@@ -58,8 +68,6 @@
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
# 设置流式一次回复的超时时间
|
||||
# STREAMING_TIMEOUT=120
|
||||
|
||||
|
||||
# 节点类型
|
||||
|
||||
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
19
.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
### PR 类型
|
||||
|
||||
- [ ] Bug 修复
|
||||
- [ ] 新功能
|
||||
- [ ] 文档更新
|
||||
- [ ] 其他
|
||||
|
||||
### PR 是否包含破坏性更新?
|
||||
|
||||
- [ ] 是
|
||||
- [ ] 否
|
||||
|
||||
### PR 描述
|
||||
|
||||
**请在下方详细描述您的 PR,包括目的、实现细节等。**
|
||||
|
||||
### **重要提示**
|
||||
|
||||
**所有 PR 都必须提交到 `alpha` 分支。请确保您的 PR 目标分支是 `alpha`。**
|
||||
1
.github/workflows/macos-release.yml
vendored
1
.github/workflows/macos-release.yml
vendored
@@ -26,6 +26,7 @@ jobs:
|
||||
- name: Build Frontend
|
||||
env:
|
||||
CI: ""
|
||||
NODE_OPTIONS: "--max-old-space-size=4096"
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
|
||||
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
21
.github/workflows/pr-target-branch-check.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Check PR Branching Strategy
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, edited]
|
||||
|
||||
jobs:
|
||||
check-branching-strategy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Enforce branching strategy
|
||||
run: |
|
||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "Branching strategy check passed."
|
||||
@@ -27,9 +27,6 @@
|
||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||
</a>
|
||||
<a href="https://coderabbit.ai">
|
||||
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
71
common/api_type.go
Normal file
71
common/api_type.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType := -1
|
||||
switch channelType {
|
||||
case constant.ChannelTypeOpenAI:
|
||||
apiType = constant.APITypeOpenAI
|
||||
case constant.ChannelTypeAnthropic:
|
||||
apiType = constant.APITypeAnthropic
|
||||
case constant.ChannelTypeBaidu:
|
||||
apiType = constant.APITypeBaidu
|
||||
case constant.ChannelTypePaLM:
|
||||
apiType = constant.APITypePaLM
|
||||
case constant.ChannelTypeZhipu:
|
||||
apiType = constant.APITypeZhipu
|
||||
case constant.ChannelTypeAli:
|
||||
apiType = constant.APITypeAli
|
||||
case constant.ChannelTypeXunfei:
|
||||
apiType = constant.APITypeXunfei
|
||||
case constant.ChannelTypeAIProxyLibrary:
|
||||
apiType = constant.APITypeAIProxyLibrary
|
||||
case constant.ChannelTypeTencent:
|
||||
apiType = constant.APITypeTencent
|
||||
case constant.ChannelTypeGemini:
|
||||
apiType = constant.APITypeGemini
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
apiType = constant.APITypeZhipuV4
|
||||
case constant.ChannelTypeOllama:
|
||||
apiType = constant.APITypeOllama
|
||||
case constant.ChannelTypePerplexity:
|
||||
apiType = constant.APITypePerplexity
|
||||
case constant.ChannelTypeAws:
|
||||
apiType = constant.APITypeAws
|
||||
case constant.ChannelTypeCohere:
|
||||
apiType = constant.APITypeCohere
|
||||
case constant.ChannelTypeDify:
|
||||
apiType = constant.APITypeDify
|
||||
case constant.ChannelTypeJina:
|
||||
apiType = constant.APITypeJina
|
||||
case constant.ChannelCloudflare:
|
||||
apiType = constant.APITypeCloudflare
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
apiType = constant.APITypeSiliconFlow
|
||||
case constant.ChannelTypeVertexAi:
|
||||
apiType = constant.APITypeVertexAi
|
||||
case constant.ChannelTypeMistral:
|
||||
apiType = constant.APITypeMistral
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
apiType = constant.APITypeDeepSeek
|
||||
case constant.ChannelTypeMokaAI:
|
||||
apiType = constant.APITypeMokaAI
|
||||
case constant.ChannelTypeVolcEngine:
|
||||
apiType = constant.APITypeVolcEngine
|
||||
case constant.ChannelTypeBaiduV2:
|
||||
apiType = constant.APITypeBaiduV2
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
apiType = constant.APITypeOpenRouter
|
||||
case constant.ChannelTypeXinference:
|
||||
apiType = constant.APITypeXinference
|
||||
case constant.ChannelTypeXai:
|
||||
apiType = constant.APITypeXai
|
||||
case constant.ChannelTypeCoze:
|
||||
apiType = constant.APITypeCoze
|
||||
}
|
||||
if apiType == -1 {
|
||||
return constant.APITypeOpenAI, false
|
||||
}
|
||||
return apiType, true
|
||||
}
|
||||
@@ -193,111 +193,3 @@ const (
|
||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
|
||||
41
common/endpoint_type.go
Normal file
41
common/endpoint_type.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package common
|
||||
|
||||
import "one-api/constant"
|
||||
|
||||
// GetEndpointTypesByChannelType 获取渠道最优先端点类型(所有的渠道都支持 OpenAI 端点)
|
||||
func GetEndpointTypesByChannelType(channelType int, modelName string) []constant.EndpointType {
|
||||
var endpointTypes []constant.EndpointType
|
||||
switch channelType {
|
||||
case constant.ChannelTypeJina:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeJinaRerank}
|
||||
//case constant.ChannelTypeMidjourney, constant.ChannelTypeMidjourneyPlus:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeMidjourney}
|
||||
//case constant.ChannelTypeSunoAPI:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeSuno}
|
||||
//case constant.ChannelTypeKling:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeKling}
|
||||
//case constant.ChannelTypeJimeng:
|
||||
// endpointTypes = []constant.EndpointType{constant.EndpointTypeJimeng}
|
||||
case constant.ChannelTypeAws:
|
||||
fallthrough
|
||||
case constant.ChannelTypeAnthropic:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeAnthropic, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
fallthrough
|
||||
case constant.ChannelTypeGemini:
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeGemini, constant.EndpointTypeOpenAI}
|
||||
case constant.ChannelTypeOpenRouter: // OpenRouter 只支持 OpenAI 端点
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
default:
|
||||
if IsOpenAIResponseOnlyModel(modelName) {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAIResponse}
|
||||
} else {
|
||||
endpointTypes = []constant.EndpointType{constant.EndpointTypeOpenAI}
|
||||
}
|
||||
}
|
||||
if IsImageGenerationModel(modelName) {
|
||||
// add to first
|
||||
endpointTypes = append([]constant.EndpointType{constant.EndpointTypeImageGeneration}, endpointTypes...)
|
||||
}
|
||||
return endpointTypes
|
||||
}
|
||||
@@ -2,10 +2,11 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const KeyRequestBody = "key_request_body"
|
||||
@@ -31,7 +32,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
err = UnmarshalJson(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
@@ -43,3 +44,35 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetContextKey(c *gin.Context, key constant.ContextKey, value any) {
|
||||
c.Set(string(key), value)
|
||||
}
|
||||
|
||||
func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) {
|
||||
return c.Get(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyString(c *gin.Context, key constant.ContextKey) string {
|
||||
return c.GetString(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int {
|
||||
return c.GetInt(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool {
|
||||
return c.GetBool(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string {
|
||||
return c.GetStringSlice(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any {
|
||||
return c.GetStringMap(string(key))
|
||||
}
|
||||
|
||||
func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time {
|
||||
return c.GetTime(string(key))
|
||||
}
|
||||
|
||||
57
common/http.go
Normal file
57
common/http.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
||||
if httpResponse == nil || httpResponse.Body == nil {
|
||||
return
|
||||
}
|
||||
err := httpResponse.Body.Close()
|
||||
if err != nil {
|
||||
SysError("failed to close response body: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||
if c.Writer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
body := io.NopCloser(bytes.NewBuffer(data))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
if src != nil {
|
||||
for k, v := range src.Header {
|
||||
// avoid setting Content-Length
|
||||
if k == "Content-Length" {
|
||||
continue
|
||||
}
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
}
|
||||
|
||||
// set Content-Length header manually BEFORE calling WriteHeader
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
|
||||
// Write header with status code (this sends the headers)
|
||||
if src != nil {
|
||||
c.Writer.WriteHeader(src.StatusCode)
|
||||
} else {
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
_, err := io.Copy(c.Writer, body)
|
||||
if err != nil {
|
||||
LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -24,7 +25,7 @@ func printHelp() {
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
}
|
||||
|
||||
func LoadEnv() {
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
@@ -95,4 +96,25 @@ func LoadEnv() {
|
||||
GlobalWebRateLimitEnable = GetEnvOrDefaultBool("GLOBAL_WEB_RATE_LIMIT_ENABLE", true)
|
||||
GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
|
||||
GlobalWebRateLimitDuration = int64(GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT_DURATION", 180))
|
||||
|
||||
initConstantEnv()
|
||||
}
|
||||
|
||||
func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
}
|
||||
|
||||
@@ -5,12 +5,16 @@ import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func DecodeJson(data []byte, v any) error {
|
||||
return json.NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||
func UnmarshalJson(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func DecodeJsonStr(data string, v any) error {
|
||||
return DecodeJson(StringToByteSlice(data), v)
|
||||
func UnmarshalJsonStr(data string, v any) error {
|
||||
return json.Unmarshal(StringToByteSlice(data), v)
|
||||
}
|
||||
|
||||
func DecodeJson(reader *bytes.Reader, v any) error {
|
||||
return json.NewDecoder(reader).Decode(v)
|
||||
}
|
||||
|
||||
func EncodeJson(v any) ([]byte, error) {
|
||||
|
||||
42
common/model.go
Normal file
42
common/model.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package common
|
||||
|
||||
import "strings"
|
||||
|
||||
var (
|
||||
// OpenAIResponseOnlyModels is a list of models that are only available for OpenAI responses.
|
||||
OpenAIResponseOnlyModels = []string{
|
||||
"o3-pro",
|
||||
"o3-deep-research",
|
||||
"o4-mini-deep-research",
|
||||
}
|
||||
ImageGenerationModels = []string{
|
||||
"dall-e-3",
|
||||
"dall-e-2",
|
||||
"gpt-image-1",
|
||||
"prefix:imagen-",
|
||||
"flux-",
|
||||
"flux.1-",
|
||||
}
|
||||
)
|
||||
|
||||
func IsOpenAIResponseOnlyModel(modelName string) bool {
|
||||
for _, m := range OpenAIResponseOnlyModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsImageGenerationModel(modelName string) bool {
|
||||
modelName = strings.ToLower(modelName)
|
||||
for _, m := range ImageGenerationModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(m, "prefix:") && strings.HasPrefix(modelName, strings.TrimPrefix(m, "prefix:")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
62
common/page_info.go
Normal file
62
common/page_info.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type PageInfo struct {
|
||||
Page int `json:"page"` // page num 页码
|
||||
PageSize int `json:"page_size"` // page size 页大小
|
||||
StartTimestamp int64 `json:"start_timestamp"` // 秒级
|
||||
EndTimestamp int64 `json:"end_timestamp"` // 秒级
|
||||
|
||||
Total int `json:"total"` // 总条数,后设置
|
||||
Items any `json:"items"` // 数据,后设置
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetStartIdx() int {
|
||||
return (p.Page - 1) * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetEndIdx() int {
|
||||
return p.Page * p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPageSize() int {
|
||||
return p.PageSize
|
||||
}
|
||||
|
||||
func (p *PageInfo) GetPage() int {
|
||||
return p.Page
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetTotal(total int) {
|
||||
p.Total = total
|
||||
}
|
||||
|
||||
func (p *PageInfo) SetItems(items any) {
|
||||
p.Items = items
|
||||
}
|
||||
|
||||
func GetPageQuery(c *gin.Context) (*PageInfo, error) {
|
||||
pageInfo := &PageInfo{}
|
||||
err := c.BindQuery(pageInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pageInfo.Page < 1 {
|
||||
// 兼容
|
||||
page, _ := strconv.Atoi(c.Query("p"))
|
||||
if page != 0 {
|
||||
pageInfo.Page = page
|
||||
} else {
|
||||
pageInfo.Page = 1
|
||||
}
|
||||
}
|
||||
|
||||
if pageInfo.PageSize == 0 {
|
||||
pageInfo.PageSize = ItemsPerPage
|
||||
}
|
||||
return pageInfo, nil
|
||||
}
|
||||
@@ -16,6 +16,10 @@ import (
|
||||
var RDB *redis.Client
|
||||
var RedisEnabled = true
|
||||
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return SyncFrequency
|
||||
}
|
||||
|
||||
// InitRedisClient This function is called after init()
|
||||
func InitRedisClient() (err error) {
|
||||
if os.Getenv("REDIS_CONN_STRING") == "" {
|
||||
|
||||
26
constant/README.md
Normal file
26
constant/README.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# constant 包 (`/constant`)
|
||||
|
||||
该目录仅用于放置全局可复用的**常量定义**,不包含任何业务逻辑或依赖关系。
|
||||
|
||||
## 当前文件
|
||||
|
||||
| 文件 | 说明 |
|
||||
|----------------------|---------------------------------------------------------------------|
|
||||
| `azure.go` | 定义与 Azure 相关的全局常量,如 `AzureNoRemoveDotTime`(控制删除 `.` 的截止时间)。 |
|
||||
| `cache_key.go` | 缓存键格式字符串及 Token 相关字段常量,统一缓存命名规则。 |
|
||||
| `channel_setting.go` | Channel 级别的设置键,如 `proxy`、`force_format` 等。 |
|
||||
| `context_key.go` | 定义 `ContextKey` 类型以及在整个项目中使用的上下文键常量(请求时间、Token/Channel/User 相关信息等)。 |
|
||||
| `env.go` | 环境配置相关的全局变量,在启动阶段根据配置文件或环境变量注入。 |
|
||||
| `finish_reason.go` | OpenAI/GPT 请求返回的 `finish_reason` 字符串常量集合。 |
|
||||
| `midjourney.go` | Midjourney 相关错误码及动作(Action)常量与模型到动作的映射表。 |
|
||||
| `setup.go` | 标识项目是否已完成初始化安装 (`Setup` 布尔值)。 |
|
||||
| `task.go` | 各种任务(Task)平台、动作常量及模型与动作映射表,如 Suno、Midjourney 等。 |
|
||||
| `user_setting.go` | 用户设置相关键常量以及通知类型(Email/Webhook)等。 |
|
||||
|
||||
## 使用约定
|
||||
|
||||
1. `constant` 包**只能被其他包引用**(import),**禁止在此包中引用项目内的其他自定义包**。如确有需要,仅允许引用 **Go 标准库**。
|
||||
2. 不允许在此目录内编写任何与业务流程、数据库操作、第三方服务调用等相关的逻辑代码。
|
||||
3. 新增类型时,请保持命名语义清晰,并在本 README 的 **当前文件** 表格中补充说明,确保团队成员能够快速了解其用途。
|
||||
|
||||
> ⚠️ 违反以上约定将导致包之间产生不必要的耦合,影响代码可维护性与可测试性。请在提交代码前自行检查。
|
||||
34
constant/api_type.go
Normal file
34
constant/api_type.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipuV4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@@ -1,12 +1,5 @@
|
||||
package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
// 使用函数来避免初始化顺序带来的赋值问题
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return common.SyncFrequency
|
||||
}
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
UserGroupKeyFmt = "user_group:%d"
|
||||
|
||||
109
constant/channel.go
Normal file
109
constant/channel.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ChannelTypeUnknown = 0
|
||||
ChannelTypeOpenAI = 1
|
||||
ChannelTypeMidjourney = 2
|
||||
ChannelTypeAzure = 3
|
||||
ChannelTypeOllama = 4
|
||||
ChannelTypeMidjourneyPlus = 5
|
||||
ChannelTypeOpenAIMax = 6
|
||||
ChannelTypeOhMyGPT = 7
|
||||
ChannelTypeCustom = 8
|
||||
ChannelTypeAILS = 9
|
||||
ChannelTypeAIProxy = 10
|
||||
ChannelTypePaLM = 11
|
||||
ChannelTypeAPI2GPT = 12
|
||||
ChannelTypeAIGC2D = 13
|
||||
ChannelTypeAnthropic = 14
|
||||
ChannelTypeBaidu = 15
|
||||
ChannelTypeZhipu = 16
|
||||
ChannelTypeAli = 17
|
||||
ChannelTypeXunfei = 18
|
||||
ChannelType360 = 19
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
ChannelTypeMoonshot = 25
|
||||
ChannelTypeZhipu_v4 = 26
|
||||
ChannelTypePerplexity = 27
|
||||
ChannelTypeLingYiWanWu = 31
|
||||
ChannelTypeAws = 33
|
||||
ChannelTypeCohere = 34
|
||||
ChannelTypeMiniMax = 35
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
ChannelTypeSiliconFlow = 40
|
||||
ChannelTypeVertexAi = 41
|
||||
ChannelTypeMistral = 42
|
||||
ChannelTypeDeepSeek = 43
|
||||
ChannelTypeMokaAI = 44
|
||||
ChannelTypeVolcEngine = 45
|
||||
ChannelTypeBaiduV2 = 46
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"http://localhost:11434", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://api.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.tencentcloudapi.com", //23
|
||||
"https://generativelanguage.googleapis.com", //24
|
||||
"https://api.moonshot.cn", //25
|
||||
"https://open.bigmodel.cn", //26
|
||||
"https://api.perplexity.ai", //27
|
||||
"", //28
|
||||
"", //29
|
||||
"", //30
|
||||
"https://api.lingyiwanwu.com", //31
|
||||
"", //32
|
||||
"", //33
|
||||
"https://api.cohere.ai", //34
|
||||
"https://api.minimax.chat", //35
|
||||
"", //36
|
||||
"https://api.dify.ai", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
"https://api.siliconflow.cn", //40
|
||||
"", //41
|
||||
"https://api.mistral.ai", //42
|
||||
"https://api.deepseek.com", //43
|
||||
"https://api.moka.ai", //44
|
||||
"https://ark.cn-beijing.volces.com", //45
|
||||
"https://qianfan.baidubce.com", //46
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
@@ -1,11 +1,35 @@
|
||||
package constant
|
||||
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
ContextKeyUserSetting = "user_setting"
|
||||
ContextKeyUserQuota = "user_quota"
|
||||
ContextKeyUserStatus = "user_status"
|
||||
ContextKeyUserEmail = "user_email"
|
||||
ContextKeyUserGroup = "user_group"
|
||||
ContextKeyUsingGroup = "group"
|
||||
ContextKeyOriginalModel ContextKey = "original_model"
|
||||
ContextKeyRequestStartTime ContextKey = "request_start_time"
|
||||
|
||||
/* token related keys */
|
||||
ContextKeyTokenUnlimited ContextKey = "token_unlimited_quota"
|
||||
ContextKeyTokenKey ContextKey = "token_key"
|
||||
ContextKeyTokenId ContextKey = "token_id"
|
||||
ContextKeyTokenGroup ContextKey = "token_group"
|
||||
ContextKeyTokenAllowIps ContextKey = "allow_ips"
|
||||
ContextKeyTokenSpecificChannelId ContextKey = "specific_channel_id"
|
||||
ContextKeyTokenModelLimitEnabled ContextKey = "token_model_limit_enabled"
|
||||
ContextKeyTokenModelLimit ContextKey = "token_model_limit"
|
||||
|
||||
/* channel related keys */
|
||||
ContextKeyBaseUrl ContextKey = "base_url"
|
||||
ContextKeyChannelType ContextKey = "channel_type"
|
||||
ContextKeyChannelId ContextKey = "channel_id"
|
||||
ContextKeyChannelSetting ContextKey = "channel_setting"
|
||||
ContextKeyParamOverride ContextKey = "param_override"
|
||||
|
||||
/* user related keys */
|
||||
ContextKeyUserId ContextKey = "id"
|
||||
ContextKeyUserSetting ContextKey = "user_setting"
|
||||
ContextKeyUserQuota ContextKey = "user_quota"
|
||||
ContextKeyUserStatus ContextKey = "user_status"
|
||||
ContextKeyUserEmail ContextKey = "user_email"
|
||||
ContextKeyUserGroup ContextKey = "user_group"
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
)
|
||||
|
||||
16
constant/endpoint_type.go
Normal file
16
constant/endpoint_type.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
type EndpointType string
|
||||
|
||||
const (
|
||||
EndpointTypeOpenAI EndpointType = "openai"
|
||||
EndpointTypeOpenAIResponse EndpointType = "openai-response"
|
||||
EndpointTypeAnthropic EndpointType = "anthropic"
|
||||
EndpointTypeGemini EndpointType = "gemini"
|
||||
EndpointTypeJinaRerank EndpointType = "jina-rerank"
|
||||
EndpointTypeImageGeneration EndpointType = "image-generation"
|
||||
//EndpointTypeMidjourney EndpointType = "midjourney-proxy"
|
||||
//EndpointTypeSuno EndpointType = "suno-proxy"
|
||||
//EndpointTypeKling EndpointType = "kling"
|
||||
//EndpointTypeJimeng EndpointType = "jimeng"
|
||||
)
|
||||
@@ -1,9 +1,5 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
@@ -17,39 +13,3 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
|
||||
//var GeminiModelMap = map[string]string{
|
||||
// "gemini-1.0-pro": "v1",
|
||||
//}
|
||||
|
||||
func InitEnv() {
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
|
||||
NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
|
||||
// GenerateDefaultToken 是否生成初始令牌,默认关闭。
|
||||
GenerateDefaultToken = common.GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
ErrorLogEnabled = common.GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
|
||||
//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
//if modelVersionMapStr == "" {
|
||||
// return
|
||||
//}
|
||||
//for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||
// parts := strings.Split(pair, ":")
|
||||
// if len(parts) == 2 {
|
||||
// GeminiModelMap[parts[0]] = parts[1]
|
||||
// } else {
|
||||
// common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ const (
|
||||
MjActionPan = "PAN"
|
||||
MjActionSwapFace = "SWAP_FACE"
|
||||
MjActionUpload = "UPLOAD"
|
||||
MjActionVideo = "VIDEO"
|
||||
MjActionEdits = "EDITS"
|
||||
)
|
||||
|
||||
var MidjourneyModel2Action = map[string]string{
|
||||
@@ -41,4 +43,6 @@ var MidjourneyModel2Action = map[string]string{
|
||||
"mj_pan": MjActionPan,
|
||||
"swap_face": MjActionSwapFace,
|
||||
"mj_upload": MjActionUpload,
|
||||
"mj_video": MjActionVideo,
|
||||
"mj_edits": MjActionEdits,
|
||||
}
|
||||
|
||||
@@ -12,6 +12,9 @@ const (
|
||||
const (
|
||||
SunoActionMusic = "MUSIC"
|
||||
SunoActionLyrics = "LYRICS"
|
||||
|
||||
TaskActionGenerate = "generate"
|
||||
TaskActionTextGenerate = "textGenerate"
|
||||
)
|
||||
|
||||
var SunoModel2Action = map[string]string{
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
@@ -341,34 +342,34 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
channel.BaseURL = &baseURL
|
||||
}
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeOpenAI:
|
||||
case constant.ChannelTypeOpenAI:
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
return 0, errors.New("尚未实现")
|
||||
case common.ChannelTypeCustom:
|
||||
case constant.ChannelTypeCustom:
|
||||
baseURL = channel.GetBaseURL()
|
||||
//case common.ChannelTypeOpenAISB:
|
||||
// return updateChannelOpenAISBBalance(channel)
|
||||
case common.ChannelTypeAIProxy:
|
||||
case constant.ChannelTypeAIProxy:
|
||||
return updateChannelAIProxyBalance(channel)
|
||||
case common.ChannelTypeAPI2GPT:
|
||||
case constant.ChannelTypeAPI2GPT:
|
||||
return updateChannelAPI2GPTBalance(channel)
|
||||
case common.ChannelTypeAIGC2D:
|
||||
case constant.ChannelTypeAIGC2D:
|
||||
return updateChannelAIGC2DBalance(channel)
|
||||
case common.ChannelTypeSiliconFlow:
|
||||
case constant.ChannelTypeSiliconFlow:
|
||||
return updateChannelSiliconFlowBalance(channel)
|
||||
case common.ChannelTypeDeepSeek:
|
||||
case constant.ChannelTypeDeepSeek:
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case common.ChannelTypeMoonshot:
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
@@ -31,19 +31,19 @@ import (
|
||||
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
|
||||
tik := time.Now()
|
||||
if channel.Type == common.ChannelTypeMidjourney {
|
||||
if channel.Type == constant.ChannelTypeMidjourney {
|
||||
return errors.New("midjourney channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported!!!"), nil
|
||||
if channel.Type == constant.ChannelTypeMidjourneyPlus {
|
||||
return errors.New("midjourney plus channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeSunoAPI {
|
||||
if channel.Type == constant.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeKling {
|
||||
if channel.Type == constant.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeJimeng {
|
||||
if channel.Type == constant.ChannelTypeJimeng {
|
||||
return errors.New("jimeng channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
@@ -56,7 +56,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
strings.Contains(testModel, "bge-") || // bge 系列模型
|
||||
strings.Contains(testModel, "embed") ||
|
||||
channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
channel.Type == constant.ChannelTypeMokaAI { // 其他 embedding 模型
|
||||
requestPath = "/v1/embeddings" // 修改请求路径
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
}
|
||||
testModel = info.UpstreamModelName
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
apiType, _ := common.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
@@ -202,7 +202,7 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
testRequest.MaxTokens = 50
|
||||
}
|
||||
} else if strings.Contains(model, "gemini") {
|
||||
testRequest.MaxTokens = 300
|
||||
testRequest.MaxTokens = 3000
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -125,7 +126,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p-1)*pageSize).Omit("key").Find(&channelData).Error
|
||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p - 1) * pageSize).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
@@ -181,15 +182,15 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
}
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
@@ -213,7 +214,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
var ids []string
|
||||
for _, model := range result.Data {
|
||||
id := model.ID
|
||||
if channel.Type == common.ChannelTypeGemini {
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
id = strings.TrimPrefix(id, "models/")
|
||||
}
|
||||
ids = append(ids, id)
|
||||
@@ -388,7 +389,7 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
channel.CreatedTime = common.GetTimestamp()
|
||||
keys := strings.Split(channel.Key, "\n")
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -613,7 +614,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type == common.ChannelTypeVertexAi {
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Other == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -668,7 +669,7 @@ func FetchModels(c *gin.Context) {
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = common.ChannelBaseURLs[req.Type]
|
||||
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -14,10 +15,7 @@ import (
|
||||
"one-api/relay/channel/minimax"
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -26,30 +24,10 @@ var openAIModels []dto.OpenAIModels
|
||||
var openAIModelsMap map[string]dto.OpenAIModels
|
||||
var channelId2Models map[int][]string
|
||||
|
||||
func getPermission() []dto.OpenAIModelPermission {
|
||||
var permission []dto.OpenAIModelPermission
|
||||
permission = append(permission, dto.OpenAIModelPermission{
|
||||
Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ",
|
||||
Object: "model_permission",
|
||||
Created: 1626777600,
|
||||
AllowCreateEngine: true,
|
||||
AllowSampling: true,
|
||||
AllowLogprobs: true,
|
||||
AllowSearchIndices: false,
|
||||
AllowView: true,
|
||||
AllowFineTuning: false,
|
||||
Organization: "*",
|
||||
Group: nil,
|
||||
IsBlocking: false,
|
||||
})
|
||||
return permission
|
||||
}
|
||||
|
||||
func init() {
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
permission := getPermission()
|
||||
for i := 0; i < relayconstant.APITypeDummy; i++ {
|
||||
if i == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 0; i < constant.APITypeDummy; i++ {
|
||||
if i == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
adaptor := relay.GetAdaptor(i)
|
||||
@@ -57,69 +35,51 @@ func init() {
|
||||
modelNames := adaptor.GetModelList()
|
||||
for _, modelName := range modelNames {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: channelName,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, modelName := range ai360.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: ai360.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range moonshot.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: moonshot.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range lingyiwanwu.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: lingyiwanwu.ChannelName,
|
||||
})
|
||||
}
|
||||
for _, modelName := range minimax.ModelList {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: minimax.ChannelName,
|
||||
})
|
||||
}
|
||||
for modelName, _ := range constant.MidjourneyModel2Action {
|
||||
openAIModels = append(openAIModels, dto.OpenAIModels{
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
Permission: permission,
|
||||
Root: modelName,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "midjourney",
|
||||
})
|
||||
}
|
||||
openAIModelsMap = make(map[string]dto.OpenAIModels)
|
||||
@@ -127,9 +87,9 @@ func init() {
|
||||
openAIModelsMap[aiModel.Id] = aiModel
|
||||
}
|
||||
channelId2Models = make(map[int][]string)
|
||||
for i := 1; i <= common.ChannelTypeDummy; i++ {
|
||||
apiType, success := relayconstant.ChannelType2APIType(i)
|
||||
if !success || apiType == relayconstant.APITypeAIProxyLibrary {
|
||||
for i := 1; i <= constant.ChannelTypeDummy; i++ {
|
||||
apiType, success := common.ChannelType2APIType(i)
|
||||
if !success || apiType == constant.APITypeAIProxyLibrary {
|
||||
continue
|
||||
}
|
||||
meta := &relaycommon.RelayInfo{ChannelType: i}
|
||||
@@ -144,11 +104,10 @@ func init() {
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
userOpenAiModels := make([]dto.OpenAIModels, 0)
|
||||
permission := getPermission()
|
||||
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -156,17 +115,16 @@ func ListModels(c *gin.Context) {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
for allowModel, _ := range tokenModelLimit {
|
||||
if _, ok := openAIModelsMap[allowModel]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[allowModel])
|
||||
if oaiModel, ok := openAIModelsMap[allowModel]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(allowModel)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: allowModel,
|
||||
Parent: nil,
|
||||
Id: allowModel,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(allowModel),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -181,14 +139,14 @@ func ListModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
group := userGroup
|
||||
tokenGroup := c.GetString("token_group")
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupModels(autoGroup)
|
||||
groupModels := model.GetGroupEnabledModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
@@ -196,20 +154,19 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupModels(group)
|
||||
models = model.GetGroupEnabledModels(group)
|
||||
}
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
for _, modelName := range models {
|
||||
if oaiModel, ok := openAIModelsMap[modelName]; ok {
|
||||
oaiModel.SupportedEndpointTypes = model.GetModelSupportEndpointTypes(modelName)
|
||||
userOpenAiModels = append(userOpenAiModels, oaiModel)
|
||||
} else {
|
||||
userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{
|
||||
Id: s,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
Permission: permission,
|
||||
Root: s,
|
||||
Parent: nil,
|
||||
Id: modelName,
|
||||
Object: "model",
|
||||
Created: 1626777600,
|
||||
OwnedBy: "custom",
|
||||
SupportedEndpointTypes: model.GetModelSupportEndpointTypes(modelName),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func Playground(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userId := c.GetInt("id")
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -69,7 +69,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
originalModel := c.GetString("original_model")
|
||||
@@ -132,7 +132,7 @@ func WssRelay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
relayMode := relayconstant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
group := c.GetString("group")
|
||||
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
|
||||
@@ -295,7 +295,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry
|
||||
}
|
||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||
channelType := c.GetInt("channel_type")
|
||||
if channelType == common.ChannelTypeAnthropic {
|
||||
if channelType == constant.ChannelTypeAnthropic {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -51,7 +51,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
@@ -246,15 +246,15 @@ func Register(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetAllUsers(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 1 {
|
||||
p = 1
|
||||
pageInfo, err := common.GetPageQuery(c)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "parse page query failed",
|
||||
})
|
||||
return
|
||||
}
|
||||
if pageSize < 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
|
||||
users, total, err := model.GetAllUsers(pageInfo)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -262,15 +262,13 @@ func GetAllUsers(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(users)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": users,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
"data": pageInfo,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -489,7 +487,7 @@ func GetUserModels(c *gin.Context) {
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
var models []string
|
||||
for group := range groups {
|
||||
for _, g := range model.GetGroupModels(group) {
|
||||
for _, g := range model.GetGroupEnabledModels(group) {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ services:
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- TZ=Asia/Shanghai
|
||||
- ERROR_LOG_ENABLED=true # 是否启用错误日志记录
|
||||
# - TIKTOKEN_CACHE_DIR=./tiktoken_cache # 如果需要使用tiktoken_cache,请取消注释
|
||||
# - STREAMING_TIMEOUT=120 # 流模式无响应超时时间,单位秒,默认120秒,如果出现空补全可以尝试改为更大值
|
||||
# - SESSION_SECRET=random_string # 多机部署时设置,必须修改这个随机字符串!!!!!!!
|
||||
# - NODE_TYPE=slave # Uncomment for slave node in multi-node deployment
|
||||
# - SYNC_FREQUENCY=60 # Uncomment if regular database syncing is needed
|
||||
|
||||
@@ -57,6 +57,8 @@ type MidjourneyDto struct {
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
VideoUrl string `json:"videoUrl"`
|
||||
VideoUrls []ImgUrls `json:"videoUrls"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
@@ -65,6 +67,10 @@ type MidjourneyDto struct {
|
||||
Properties *Properties `json:"properties"`
|
||||
}
|
||||
|
||||
type ImgUrls struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
|
||||
type MidjourneyStatus struct {
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ type GeneralOpenAIRequest struct {
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
result := make(map[string]any)
|
||||
data, _ := common.EncodeJson(r)
|
||||
_ = common.DecodeJson(data, &result)
|
||||
_ = common.UnmarshalJson(data, &result)
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
@@ -1,26 +1,11 @@
|
||||
package dto
|
||||
|
||||
type OpenAIModelPermission struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
AllowCreateEngine bool `json:"allow_create_engine"`
|
||||
AllowSampling bool `json:"allow_sampling"`
|
||||
AllowLogprobs bool `json:"allow_logprobs"`
|
||||
AllowSearchIndices bool `json:"allow_search_indices"`
|
||||
AllowView bool `json:"allow_view"`
|
||||
AllowFineTuning bool `json:"allow_fine_tuning"`
|
||||
Organization string `json:"organization"`
|
||||
Group *string `json:"group"`
|
||||
IsBlocking bool `json:"is_blocking"`
|
||||
}
|
||||
import "one-api/constant"
|
||||
|
||||
type OpenAIModels struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Permission []OpenAIModelPermission `json:"permission"`
|
||||
Root string `json:"root"`
|
||||
Parent *string `json:"parent"`
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ type RerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n"`
|
||||
TopN int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||
|
||||
1041
i18n/zh-cn.json
Normal file
1041
i18n/zh-cn.json
Normal file
File diff suppressed because it is too large
Load Diff
85
main.go
85
main.go
@@ -32,12 +32,12 @@ var buildFS embed.FS
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
common.SysLog("Support for .env file is disabled: " + err.Error())
|
||||
}
|
||||
|
||||
common.LoadEnv()
|
||||
err := InitResources()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize resources: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
common.SetupLogger()
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
@@ -47,19 +47,7 @@ func main() {
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("running in debug mode")
|
||||
}
|
||||
// Initialize SQL Database
|
||||
err = model.InitDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
|
||||
model.CheckSetup()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
}
|
||||
defer func() {
|
||||
err := model.CloseDB()
|
||||
if err != nil {
|
||||
@@ -67,21 +55,6 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Initialize Redis
|
||||
err = common.InitRedisClient()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
ratio_setting.InitRatioSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
model.InitOptionMap()
|
||||
|
||||
service.InitTokenEncoders()
|
||||
|
||||
if common.RedisEnabled {
|
||||
// for compatibility with old versions
|
||||
common.MemoryCacheEnabled = true
|
||||
@@ -186,3 +159,51 @@ func main() {
|
||||
common.FatalLog("failed to start HTTP server: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func InitResources() error {
|
||||
// Initialize resources here if needed
|
||||
// This is a placeholder function for future resource initialization
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
common.SysLog("未找到 .env 文件,使用默认环境变量,如果需要,请创建 .env 文件并设置相关变量")
|
||||
common.SysLog("No .env file found, using default environment variables. If needed, please create a .env file and set the relevant variables.")
|
||||
}
|
||||
|
||||
// 加载环境变量
|
||||
common.InitEnv()
|
||||
|
||||
// Initialize model settings
|
||||
ratio_setting.InitRatioSettings()
|
||||
|
||||
service.InitHttpClient()
|
||||
|
||||
service.InitTokenEncoders()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitDB()
|
||||
if err != nil {
|
||||
common.FatalLog("failed to initialize database: " + err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
model.CheckSetup()
|
||||
|
||||
// Initialize options, should after model.InitDB()
|
||||
model.InitOptionMap()
|
||||
|
||||
// 初始化模型
|
||||
model.GetPricing()
|
||||
|
||||
// Initialize SQL Database
|
||||
err = model.InitLogDB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Redis
|
||||
err = common.InitRedisClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ type ModelRequest struct {
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
allowIpsMap := c.GetStringMap("allow_ips")
|
||||
allowIpsMap := common.GetContextKeyStringMap(c, constant.ContextKeyTokenAllowIps)
|
||||
if len(allowIpsMap) != 0 {
|
||||
clientIp := c.ClientIP()
|
||||
if _, ok := allowIpsMap[clientIp]; !ok {
|
||||
@@ -34,14 +34,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup := c.GetString(constant.ContextKeyUserGroup)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
tokenGroup := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
@@ -57,7 +57,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
c.Set(constant.ContextKeyUsingGroup, userGroup)
|
||||
common.SetContextKey(c, constant.ContextKeyUsingGroup, userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -76,9 +76,9 @@ func Distribute() func(c *gin.Context) {
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
@@ -121,7 +121,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
@@ -261,21 +261,21 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeVertexAi:
|
||||
case constant.ChannelTypeVertexAi:
|
||||
c.Set("region", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
case constant.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
case constant.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
case common.ChannelCloudflare:
|
||||
case constant.ChannelCloudflare:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeMokaAI:
|
||||
case constant.ChannelTypeMokaAI:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeCoze:
|
||||
case constant.ChannelTypeCoze:
|
||||
c.Set("bot_id", channel.Other)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,11 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func KlingRequestConvert() func(c *gin.Context) {
|
||||
@@ -35,7 +37,7 @@ func KlingRequestConvert() func(c *gin.Context) {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||
c.Request.URL.Path = "/v1/video/generations"
|
||||
if image := originalReq["image"]; image == "" {
|
||||
c.Set("action", "textGenerate")
|
||||
c.Set("action", constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// We have to reset the request body for the next handlers
|
||||
|
||||
@@ -177,9 +177,9 @@ func ModelRequestRateLimit() func(c *gin.Context) {
|
||||
successMaxCount := setting.ModelRequestRateLimitSuccessCount
|
||||
|
||||
// 获取分组
|
||||
group := c.GetString("token_group")
|
||||
group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
|
||||
if group == "" {
|
||||
group = c.GetString(constant.ContextKeyUserGroup)
|
||||
group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
|
||||
}
|
||||
|
||||
//获取分组的限流配置
|
||||
|
||||
@@ -21,7 +21,22 @@ type Ability struct {
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
}
|
||||
|
||||
func GetGroupModels(group string) []string {
|
||||
type AbilityWithChannel struct {
|
||||
Ability
|
||||
ChannelType int `json:"channel_type"`
|
||||
}
|
||||
|
||||
func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
||||
var abilities []AbilityWithChannel
|
||||
err := DB.Table("abilities").
|
||||
Select("abilities.*, channels.type as channel_type").
|
||||
Joins("left join channels on abilities.channel_id = channels.id").
|
||||
Where("abilities.enabled = ?", true).
|
||||
Scan(&abilities).Error
|
||||
return abilities, err
|
||||
}
|
||||
|
||||
func GetGroupEnabledModels(group string) []string {
|
||||
var models []string
|
||||
// Find distinct models
|
||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
@@ -46,7 +61,7 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
@@ -72,14 +87,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ type Midjourney struct {
|
||||
StartTime int64 `json:"start_time" gorm:"index"`
|
||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
VideoUrl string `json:"video_url"`
|
||||
VideoUrls string `json:"video_urls"`
|
||||
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
|
||||
110
model/pricing.go
110
model/pricing.go
@@ -1,20 +1,24 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Pricing struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups,omitempty"`
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
OwnerBy string `json:"owner_by"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
EnableGroup []string `json:"enable_groups"`
|
||||
SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -23,47 +27,89 @@ var (
|
||||
updatePricingLock sync.Mutex
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
var (
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
modelSupportEndpointsLock = sync.RWMutex{}
|
||||
)
|
||||
|
||||
func GetPricing() []Pricing {
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
updatePricing()
|
||||
updatePricingLock.Lock()
|
||||
defer updatePricingLock.Unlock()
|
||||
// Double check after acquiring the lock
|
||||
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
|
||||
modelSupportEndpointsLock.Lock()
|
||||
defer modelSupportEndpointsLock.Unlock()
|
||||
updatePricing()
|
||||
}
|
||||
}
|
||||
//if group != "" {
|
||||
// userPricingMap := make([]Pricing, 0)
|
||||
// models := GetGroupModels(group)
|
||||
// for _, pricing := range pricingMap {
|
||||
// if !common.StringsContains(models, pricing.ModelName) {
|
||||
// pricing.Available = false
|
||||
// }
|
||||
// userPricingMap = append(userPricingMap, pricing)
|
||||
// }
|
||||
// return userPricingMap
|
||||
//}
|
||||
return pricingMap
|
||||
}
|
||||
|
||||
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
|
||||
if model == "" {
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
modelSupportEndpointsLock.RLock()
|
||||
defer modelSupportEndpointsLock.RUnlock()
|
||||
if endpoints, ok := modelSupportEndpointTypes[model]; ok {
|
||||
return endpoints
|
||||
}
|
||||
return make([]constant.EndpointType, 0)
|
||||
}
|
||||
|
||||
func updatePricing() {
|
||||
//modelRatios := common.GetModelRatios()
|
||||
enableAbilities := GetAllEnableAbilities()
|
||||
modelGroupsMap := make(map[string][]string)
|
||||
enableAbilities, err := GetAllEnableAbilityWithChannels()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
|
||||
return
|
||||
}
|
||||
modelGroupsMap := make(map[string]*types.Set[string])
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
groups := modelGroupsMap[ability.Model]
|
||||
if groups == nil {
|
||||
groups = make([]string, 0)
|
||||
groups, ok := modelGroupsMap[ability.Model]
|
||||
if !ok {
|
||||
groups = types.NewSet[string]()
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
}
|
||||
if !common.StringsContains(groups, ability.Group) {
|
||||
groups = append(groups, ability.Group)
|
||||
groups.Add(ability.Group)
|
||||
}
|
||||
|
||||
//这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
|
||||
modelSupportEndpointsStr := make(map[string][]string)
|
||||
|
||||
for _, ability := range enableAbilities {
|
||||
endpoints, ok := modelSupportEndpointsStr[ability.Model]
|
||||
if !ok {
|
||||
endpoints = make([]string, 0)
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
modelGroupsMap[ability.Model] = groups
|
||||
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
|
||||
for _, channelType := range channelTypes {
|
||||
if !common.StringsContains(endpoints, string(channelType)) {
|
||||
endpoints = append(endpoints, string(channelType))
|
||||
}
|
||||
}
|
||||
modelSupportEndpointsStr[ability.Model] = endpoints
|
||||
}
|
||||
|
||||
modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
|
||||
for model, endpoints := range modelSupportEndpointsStr {
|
||||
supportedEndpoints := make([]constant.EndpointType, 0)
|
||||
for _, endpointStr := range endpoints {
|
||||
endpointType := constant.EndpointType(endpointStr)
|
||||
supportedEndpoints = append(supportedEndpoints, endpointType)
|
||||
}
|
||||
modelSupportEndpointTypes[model] = supportedEndpoints
|
||||
}
|
||||
|
||||
pricingMap = make([]Pricing, 0)
|
||||
for model, groups := range modelGroupsMap {
|
||||
pricing := Pricing{
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
ModelName: model,
|
||||
EnableGroup: groups.Items(),
|
||||
SupportedEndpointTypes: modelSupportEndpointTypes[model],
|
||||
}
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(common.RedisKeyCacheSeconds())*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func GetMaxUserId() int {
|
||||
return user.Id
|
||||
}
|
||||
|
||||
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
|
||||
func GetAllUsers(pageInfo *common.PageInfo) (users []*User, total int64, err error) {
|
||||
// Start transaction
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
@@ -134,7 +134,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error)
|
||||
}
|
||||
|
||||
// Get paginated users within same transaction
|
||||
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
||||
err = tx.Unscoped().Order("id desc").Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("password").Find(&users).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, 0, err
|
||||
|
||||
@@ -24,12 +24,12 @@ type UserBase struct {
|
||||
}
|
||||
|
||||
func (user *UserBase) WriteContext(c *gin.Context) {
|
||||
c.Set(constant.ContextKeyUserGroup, user.Group)
|
||||
c.Set(constant.ContextKeyUserQuota, user.Quota)
|
||||
c.Set(constant.ContextKeyUserStatus, user.Status)
|
||||
c.Set(constant.ContextKeyUserEmail, user.Email)
|
||||
c.Set("username", user.Username)
|
||||
c.Set(constant.ContextKeyUserSetting, user.GetSetting())
|
||||
common.SetContextKey(c, constant.ContextKeyUserGroup, user.Group)
|
||||
common.SetContextKey(c, constant.ContextKeyUserQuota, user.Quota)
|
||||
common.SetContextKey(c, constant.ContextKeyUserStatus, user.Status)
|
||||
common.SetContextKey(c, constant.ContextKeyUserEmail, user.Email)
|
||||
common.SetContextKey(c, constant.ContextKeyUserName, user.Username)
|
||||
common.SetContextKey(c, constant.ContextKeyUserSetting, user.GetSetting())
|
||||
}
|
||||
|
||||
func (user *UserBase) GetSetting() map[string]interface{} {
|
||||
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
|
||||
return common.RedisHSetObj(
|
||||
getUserCacheKey(user.Id),
|
||||
user.ToBaseUser(),
|
||||
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
|
||||
time.Duration(common.RedisKeyCacheSeconds())*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
var fullRequestURL string
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeEmbeddings:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
@@ -82,7 +82,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return embeddingRequestOpenAI2Ali(request), nil
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
|
||||
@@ -132,10 +132,7 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliTaskResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
@@ -35,10 +36,7 @@ func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var aliResponse AliRerankResponse
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
|
||||
@@ -39,34 +39,18 @@ func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingReque
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
var aliResponse AliEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
var fullTextResponse dto.OpenAIEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
model := c.GetString("model")
|
||||
if model == "" {
|
||||
model = "text-embedding-v4"
|
||||
}
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse, model)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -186,10 +170,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -199,10 +180,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -166,10 +166,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
@@ -179,10 +176,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -215,10 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -280,7 +271,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
res, err := service.GetImpatientHttpClient().Do(req)
|
||||
res, err := service.GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
|
||||
if textRequest.Reasoning != nil {
|
||||
var reasoning openrouter.RequestReasoning
|
||||
if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil {
|
||||
if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -519,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
|
||||
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.DecodeJsonStr(data, &claudeResponse)
|
||||
err := common.UnmarshalJsonStr(data, &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
|
||||
@@ -619,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.DecodeJson(data, &claudeResponse)
|
||||
err := common.UnmarshalJson(data, &claudeResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -657,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
case relaycommon.RelayFormatClaude:
|
||||
responseData = data
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, err = c.Writer.Write(responseData)
|
||||
|
||||
common.IOCopyBytesGracefully(c, nil, responseData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
claudeInfo := &ClaudeResponseInfo{
|
||||
ResponseId: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Created: common.GetTimestamp(),
|
||||
@@ -675,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
}
|
||||
|
||||
@@ -81,10 +81,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
}
|
||||
helper.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
@@ -94,10 +91,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
@@ -127,10 +121,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &cfResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -173,10 +173,7 @@ func cohereHandler(c *gin.Context, resp *http.Response, modelName string, prompt
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
@@ -217,10 +214,7 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var cohereResp CohereRerankResponseResult
|
||||
err = json.Unmarshal(responseBody, &cohereResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -48,10 +48,7 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
// convert coze response to openai response
|
||||
var response dto.TextResponse
|
||||
var cozeResponse CozeChatDetailResponse
|
||||
|
||||
@@ -95,7 +95,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
|
||||
// Send request
|
||||
client := service.GetImpatientHttpClient()
|
||||
client := service.GetHttpClient()
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysError("failed to send request: " + err.Error())
|
||||
@@ -257,10 +257,7 @@ func difyHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInf
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &difyResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -15,15 +14,13 @@ import (
|
||||
)
|
||||
|
||||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
@@ -31,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = common.DecodeJson(responseBody, &geminiResponse)
|
||||
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -54,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := json.Marshal(geminiResponse)
|
||||
jsonResponse, err := common.EncodeJson(geminiResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 设置响应头并写入响应
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
|
||||
return &usage, nil
|
||||
}
|
||||
@@ -80,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
|
||||
@@ -801,7 +801,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
@@ -866,15 +866,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
println(string(responseBody))
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = common.DecodeJson(responseBody, &geminiResponse)
|
||||
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -920,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
}
|
||||
|
||||
func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var geminiResponse GeminiEmbeddingResponse
|
||||
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
|
||||
@@ -956,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
|
||||
}
|
||||
openAIResponse.Usage = *usage.(*dto.Usage)
|
||||
|
||||
jsonResponse, jsonErr := json.Marshal(openAIResponse)
|
||||
jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
|
||||
if jsonErr != nil {
|
||||
return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
|
||||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package jina
|
||||
var ModelList = []string{
|
||||
"jina-clip-v1",
|
||||
"jina-reranker-v2-base-multilingual",
|
||||
"jina-reranker-m0",
|
||||
}
|
||||
|
||||
var ChannelName = "jina"
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
)
|
||||
@@ -26,7 +27,7 @@ func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.Embeddin
|
||||
}
|
||||
return &dto.EmbeddingRequest{
|
||||
Input: input,
|
||||
Model: request.Model,
|
||||
Model: request.Model,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,10 +54,7 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -80,4 +78,3 @@ func mokaEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
@@ -88,10 +88,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &ollamaEmbeddingResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -120,31 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
// Copy headers
|
||||
for k, v := range resp.Header {
|
||||
// 删除任何现有的相同头部,以防止重复添加头部
|
||||
c.Writer.Header().Del(k)
|
||||
for _, vv := range v {
|
||||
c.Writer.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
// reset content length
|
||||
c.Writer.Header().Del("Content-Length")
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, resp, doResponseBody)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,7 @@ import (
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ai360"
|
||||
@@ -21,7 +20,7 @@ import (
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/common_handler"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -54,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
|
||||
// initialize ThinkingContentInfo when thinking_to_content is enabled
|
||||
if think2Content, ok := info.ChannelSetting[constant2.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||
if think2Content, ok := info.ChannelSetting[constant.ChannelSettingThinkingToContent].(bool); ok && think2Content {
|
||||
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
|
||||
IsFirstThinkingContent: true,
|
||||
SendLastThinkingContent: false,
|
||||
@@ -67,7 +66,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
if strings.HasPrefix(info.BaseUrl, "https://") {
|
||||
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
||||
baseUrl = "wss://" + baseUrl
|
||||
@@ -79,10 +78,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
}
|
||||
}
|
||||
switch info.ChannelType {
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
apiVersion := info.ApiVersion
|
||||
if apiVersion == "" {
|
||||
apiVersion = constant2.AzureDefaultAPIVersion
|
||||
apiVersion = constant.AzureDefaultAPIVersion
|
||||
}
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||
@@ -90,25 +89,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
|
||||
// 特殊处理 responses API
|
||||
if info.RelayMode == constant.RelayModeResponses {
|
||||
if info.RelayMode == relayconstant.RelayModeResponses {
|
||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
|
||||
model_ := info.UpstreamModelName
|
||||
// 2025年5月10日后创建的渠道不移除.
|
||||
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
|
||||
if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
|
||||
model_ = strings.Replace(model_, ".", "", -1)
|
||||
}
|
||||
// https://github.com/songquanpeng/one-api/issues/67
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
|
||||
}
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
case common.ChannelTypeMiniMax:
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.GetRequestURL(info)
|
||||
case common.ChannelTypeCustom:
|
||||
case constant.ChannelTypeCustom:
|
||||
url := info.BaseUrl
|
||||
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
||||
return url, nil
|
||||
@@ -119,14 +118,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, header)
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
if info.ChannelType == constant.ChannelTypeAzure {
|
||||
header.Set("api-key", info.ApiKey)
|
||||
return nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
||||
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
|
||||
header.Set("OpenAI-Organization", info.Organization)
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeRealtime {
|
||||
if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
||||
if swp != "" {
|
||||
items := []string{
|
||||
@@ -145,7 +144,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *
|
||||
} else {
|
||||
header.Set("Authorization", "Bearer "+info.ApiKey)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
header.Set("HTTP-Referer", "https://github.com/Calcium-Ion/new-api")
|
||||
header.Set("X-Title", "New API")
|
||||
}
|
||||
@@ -156,10 +155,10 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
||||
if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
if len(request.Usage) == 0 {
|
||||
request.Usage = json.RawMessage(`{"include":true}`)
|
||||
}
|
||||
@@ -205,7 +204,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
a.ResponseFormat = request.ResponseFormat
|
||||
if info.RelayMode == constant.RelayModeAudioSpeech {
|
||||
if info.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshalling object: %w", err)
|
||||
@@ -254,7 +253,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesEdits:
|
||||
case relayconstant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
@@ -411,11 +410,11 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeAudioTranscription ||
|
||||
info.RelayMode == constant.RelayModeAudioTranslation ||
|
||||
info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
|
||||
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
|
||||
info.RelayMode == relayconstant.RelayModeImagesEdits {
|
||||
return channel.DoFormRequest(a, c, info, requestBody)
|
||||
} else if info.RelayMode == constant.RelayModeRealtime {
|
||||
} else if info.RelayMode == relayconstant.RelayModeRealtime {
|
||||
return channel.DoWssRequest(a, c, info, requestBody)
|
||||
} else {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
@@ -424,19 +423,19 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeRealtime:
|
||||
case relayconstant.RelayModeRealtime:
|
||||
err, usage = OpenaiRealtimeHandler(c, info)
|
||||
case constant.RelayModeAudioSpeech:
|
||||
case relayconstant.RelayModeAudioSpeech:
|
||||
err, usage = OpenaiTTSHandler(c, resp, info)
|
||||
case constant.RelayModeAudioTranslation:
|
||||
case relayconstant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
case relayconstant.RelayModeAudioTranscription:
|
||||
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
|
||||
err, usage = OpenaiHandlerWithUsage(c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
case relayconstant.RelayModeRerank:
|
||||
err, usage = common_handler.RerankHandler(c, info, resp)
|
||||
case constant.RelayModeResponses:
|
||||
case relayconstant.RelayModeResponses:
|
||||
if info.IsStream {
|
||||
err, usage = OaiResponsesStreamHandler(c, resp, info)
|
||||
} else {
|
||||
@@ -454,17 +453,17 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
switch a.ChannelType {
|
||||
case common.ChannelType360:
|
||||
case constant.ChannelType360:
|
||||
return ai360.ModelList
|
||||
case common.ChannelTypeMoonshot:
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ModelList
|
||||
case common.ChannelTypeLingYiWanWu:
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ModelList
|
||||
case common.ChannelTypeMiniMax:
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.ModelList
|
||||
case common.ChannelTypeXinference:
|
||||
case constant.ChannelTypeXinference:
|
||||
return xinference.ModelList
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return openrouter.ModelList
|
||||
default:
|
||||
return ModelList
|
||||
@@ -473,17 +472,17 @@ func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
switch a.ChannelType {
|
||||
case common.ChannelType360:
|
||||
case constant.ChannelType360:
|
||||
return ai360.ChannelName
|
||||
case common.ChannelTypeMoonshot:
|
||||
case constant.ChannelTypeMoonshot:
|
||||
return moonshot.ChannelName
|
||||
case common.ChannelTypeLingYiWanWu:
|
||||
case constant.ChannelTypeLingYiWanWu:
|
||||
return lingyiwanwu.ChannelName
|
||||
case common.ChannelTypeMiniMax:
|
||||
case constant.ChannelTypeMiniMax:
|
||||
return minimax.ChannelName
|
||||
case common.ChannelTypeXinference:
|
||||
case constant.ChannelTypeXinference:
|
||||
return xinference.ChannelName
|
||||
case common.ChannelTypeOpenRouter:
|
||||
case constant.ChannelTypeOpenRouter:
|
||||
return openrouter.ChannelName
|
||||
default:
|
||||
return ChannelName
|
||||
|
||||
@@ -2,7 +2,6 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@@ -34,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
|
||||
}
|
||||
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
|
||||
if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -111,12 +110,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
containStreamUsage := false
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
model := info.UpstreamModelName
|
||||
var responseId string
|
||||
var createAt int64 = 0
|
||||
var systemFingerprint string
|
||||
model := info.UpstreamModelName
|
||||
|
||||
var containStreamUsage bool
|
||||
var responseTextBuilder strings.Builder
|
||||
var toolCount int
|
||||
var usage = &dto.Usage{}
|
||||
@@ -148,31 +148,15 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
return true
|
||||
})
|
||||
|
||||
// 处理最后的响应
|
||||
shouldSendLastResp := true
|
||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||
err := common.DecodeJsonStr(lastStreamData, &lastStreamResponse)
|
||||
if err == nil {
|
||||
responseId = lastStreamResponse.Id
|
||||
createAt = lastStreamResponse.Created
|
||||
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
||||
model = lastStreamResponse.Model
|
||||
if service.ValidUsage(lastStreamResponse.Usage) {
|
||||
containStreamUsage = true
|
||||
usage = lastStreamResponse.Usage
|
||||
if !info.ShouldIncludeUsage {
|
||||
shouldSendLastResp = false
|
||||
}
|
||||
}
|
||||
for _, choice := range lastStreamResponse.Choices {
|
||||
if choice.FinishReason != nil {
|
||||
shouldSendLastResp = true
|
||||
}
|
||||
}
|
||||
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
|
||||
&containStreamUsage, info, &shouldSendLastResp); err != nil {
|
||||
common.SysError("error handling last response: " + err.Error())
|
||||
}
|
||||
|
||||
if shouldSendLastResp {
|
||||
sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
|
||||
}
|
||||
|
||||
// 处理token计算
|
||||
@@ -184,7 +168,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||
if info.ChannelType == constant.ChannelTypeDeepSeek {
|
||||
if usage.PromptCacheHitTokens != 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
@@ -197,16 +181,14 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var simpleResponse dto.OpenAITextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = common.DecodeJson(responseBody, &simpleResponse)
|
||||
err = common.UnmarshalJson(responseBody, &simpleResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -238,7 +220,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
switch info.RelayFormat {
|
||||
case relaycommon.RelayFormatOpenAI:
|
||||
if forceFormat {
|
||||
responseBody, err = json.Marshal(simpleResponse)
|
||||
responseBody, err = common.EncodeJson(simpleResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -247,29 +229,15 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
}
|
||||
case relaycommon.RelayFormatClaude:
|
||||
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
|
||||
claudeRespStr, err := json.Marshal(claudeResp)
|
||||
claudeRespStr, err := common.EncodeJson(claudeResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
responseBody = claudeRespStr
|
||||
}
|
||||
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("error copying response body: " + err.Error())
|
||||
}
|
||||
resp.Body.Close()
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
return nil, &simpleResponse.Usage
|
||||
}
|
||||
|
||||
@@ -280,7 +248,7 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// and can be terminated directly.
|
||||
defer resp.Body.Close()
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.TotalTokens = info.PromptTokens
|
||||
@@ -297,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// count tokens by audio file duration
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
@@ -306,25 +276,8 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = audioTokens
|
||||
@@ -415,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
}
|
||||
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = json.Unmarshal(message, realtimeEvent)
|
||||
err = common.UnmarshalJson(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
@@ -475,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
}
|
||||
info.SetFirstResponseTime()
|
||||
realtimeEvent := &dto.RealtimeEvent{}
|
||||
err = json.Unmarshal(message, realtimeEvent)
|
||||
err = common.UnmarshalJson(message, realtimeEvent)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
|
||||
return
|
||||
@@ -522,9 +475,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
|
||||
localUsage = &dto.RealtimeUsage{}
|
||||
// print now usage
|
||||
}
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
//common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
|
||||
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
|
||||
|
||||
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
|
||||
realtimeSession := realtimeEvent.Session
|
||||
@@ -601,40 +554,25 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
|
||||
}
|
||||
|
||||
func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
|
||||
var usageResp dto.SimpleResponse
|
||||
err = common.UnmarshalJson(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
// reset content length
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(responseBody)))
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
common.SysError("error copying response body: " + err.Error())
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
// Once we've written to the client, we should not return errors anymore
|
||||
// because the upstream has already consumed resources and returned content
|
||||
// We should still perform billing even if parsing fails
|
||||
var usageResp dto.SimpleResponse
|
||||
err = json.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
// format
|
||||
if usageResp.InputTokens > 0 {
|
||||
usageResp.PromptTokens += usageResp.InputTokens
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -16,17 +15,15 @@ import (
|
||||
)
|
||||
|
||||
func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// read response body
|
||||
var responsesResponse dto.OpenAIResponsesResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = common.DecodeJson(responseBody, &responsesResponse)
|
||||
err = common.UnmarshalJson(responseBody, &responsesResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -41,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
// copy response body
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
common.SysError("error copying response body: " + err.Error())
|
||||
}
|
||||
resp.Body.Close()
|
||||
// 写入新的 response body
|
||||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
// compute usage
|
||||
usage := dto.Usage{}
|
||||
usage.PromptTokens = responsesResponse.Usage.InputTokens
|
||||
@@ -82,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
||||
|
||||
// 检查当前数据是否包含 completed 状态和 usage 信息
|
||||
var streamResponse dto.ResponsesStreamResponse
|
||||
if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
|
||||
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
|
||||
@@ -83,12 +83,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
common.SysError("error closing stream response: " + err.Error())
|
||||
stopChan <- true
|
||||
return
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
@@ -122,10 +117,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
@@ -134,10 +126,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/service"
|
||||
)
|
||||
@@ -14,10 +15,7 @@ func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIE
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var siliconflowResp SFRerankResponse
|
||||
err = json.Unmarshal(responseBody, &siliconflowResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -77,8 +78,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
|
||||
// apiKey format: "access_key,secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, ",")
|
||||
// apiKey format: "access_key|secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 2 {
|
||||
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||
@@ -88,7 +89,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
@@ -192,9 +193,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
keyParts := strings.Split(key, ",")
|
||||
keyParts := strings.Split(key, "|")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak,sk'")
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -92,7 +93,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
@@ -112,7 +113,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
@@ -198,7 +199,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid action")
|
||||
}
|
||||
path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
|
||||
@@ -124,10 +124,7 @@ func tencentStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIError
|
||||
|
||||
helper.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
@@ -138,10 +135,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &tencentSb)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"fmt"
|
||||
@@ -45,7 +46,7 @@ func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||
}
|
||||
newToken, err := exchangeJwtForAccessToken(signedJWT)
|
||||
newToken, err := exchangeJwtForAccessToken(signedJWT, info)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
|
||||
}
|
||||
@@ -96,14 +97,25 @@ func createSignedJWT(email, privateKeyPEM string) (string, error) {
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
|
||||
func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
resp, err := http.PostForm(authURL, data)
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|
||||
client, err = service.NewProxyHttpClient(proxyURL.(string))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
client = service.GetHttpClient()
|
||||
}
|
||||
|
||||
resp, err := client.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -13,6 +11,8 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
|
||||
@@ -73,18 +73,16 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
//return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
defer common.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
var response *dto.TextResponse
|
||||
err = common.DecodeJson(responseBody, &response)
|
||||
var response *dto.SimpleResponse
|
||||
err = common.UnmarshalJson(responseBody, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return nil, nil
|
||||
@@ -99,21 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// set new body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, resp, encodeJson)
|
||||
|
||||
return nil, &response.Usage
|
||||
}
|
||||
|
||||
@@ -210,10 +210,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWi
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -223,10 +220,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
|
||||
@@ -113,17 +113,17 @@ type RelayInfo struct {
|
||||
|
||||
// 定义支持流式选项的通道类型
|
||||
var streamSupportedChannels = map[int]bool{
|
||||
common.ChannelTypeOpenAI: true,
|
||||
common.ChannelTypeAnthropic: true,
|
||||
common.ChannelTypeAws: true,
|
||||
common.ChannelTypeGemini: true,
|
||||
common.ChannelCloudflare: true,
|
||||
common.ChannelTypeAzure: true,
|
||||
common.ChannelTypeVolcEngine: true,
|
||||
common.ChannelTypeOllama: true,
|
||||
common.ChannelTypeXai: true,
|
||||
common.ChannelTypeDeepSeek: true,
|
||||
common.ChannelTypeBaiduV2: true,
|
||||
constant.ChannelTypeOpenAI: true,
|
||||
constant.ChannelTypeAnthropic: true,
|
||||
constant.ChannelTypeAws: true,
|
||||
constant.ChannelTypeGemini: true,
|
||||
constant.ChannelCloudflare: true,
|
||||
constant.ChannelTypeAzure: true,
|
||||
constant.ChannelTypeVolcEngine: true,
|
||||
constant.ChannelTypeOllama: true,
|
||||
constant.ChannelTypeXai: true,
|
||||
constant.ChannelTypeDeepSeek: true,
|
||||
constant.ChannelTypeBaiduV2: true,
|
||||
}
|
||||
|
||||
func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
|
||||
@@ -211,40 +211,40 @@ func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelSetting := c.GetStringMap("channel_setting")
|
||||
paramOverride := c.GetStringMap("param_override")
|
||||
channelType := common.GetContextKeyInt(c, constant.ContextKeyChannelType)
|
||||
channelId := common.GetContextKeyInt(c, constant.ContextKeyChannelId)
|
||||
channelSetting := common.GetContextKeyStringMap(c, constant.ContextKeyChannelSetting)
|
||||
paramOverride := common.GetContextKeyStringMap(c, constant.ContextKeyParamOverride)
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
tokenKey := c.GetString("token_key")
|
||||
userId := c.GetInt("id")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
||||
tokenId := common.GetContextKeyInt(c, constant.ContextKeyTokenId)
|
||||
tokenKey := common.GetContextKeyString(c, constant.ContextKeyTokenKey)
|
||||
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
|
||||
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
|
||||
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
apiType, _ := common.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
UserQuota: c.GetInt(constant.ContextKeyUserQuota),
|
||||
UserSetting: c.GetStringMap(constant.ContextKeyUserSetting),
|
||||
UserEmail: c.GetString(constant.ContextKeyUserEmail),
|
||||
UserQuota: common.GetContextKeyInt(c, constant.ContextKeyUserQuota),
|
||||
UserSetting: common.GetContextKeyStringMap(c, constant.ContextKeyUserSetting),
|
||||
UserEmail: common.GetContextKeyString(c, constant.ContextKeyUserEmail),
|
||||
isFirstResponse: true,
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
BaseUrl: common.GetContextKeyString(c, constant.ContextKeyBaseUrl),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
ChannelId: channelId,
|
||||
TokenId: tokenId,
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
UsingGroup: c.GetString(constant.ContextKeyUsingGroup),
|
||||
UserGroup: c.GetString(constant.ContextKeyUserGroup),
|
||||
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
OriginModelName: c.GetString("original_model"),
|
||||
UpstreamModelName: c.GetString("original_model"),
|
||||
OriginModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
|
||||
//RecodeModelName: c.GetString("original_model"),
|
||||
IsModelMapped: false,
|
||||
ApiType: apiType,
|
||||
@@ -266,12 +266,12 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
info.RequestURLPath = "/v1" + info.RequestURLPath
|
||||
}
|
||||
if info.BaseUrl == "" {
|
||||
info.BaseUrl = common.ChannelBaseURLs[channelType]
|
||||
info.BaseUrl = constant.ChannelBaseURLs[channelType]
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeAzure {
|
||||
if info.ChannelType == constant.ChannelTypeAzure {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeVertexAi {
|
||||
if info.ChannelType == constant.ChannelTypeVertexAi {
|
||||
info.ApiVersion = c.GetString("region")
|
||||
}
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -15,9 +15,9 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
case constant.ChannelTypeOpenAI:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||
case common.ChannelTypeAzure:
|
||||
case constant.ChannelTypeAzure:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/xinference"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -16,17 +17,14 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
if common.DebugEnabled {
|
||||
println("reranker response body: ", string(responseBody))
|
||||
}
|
||||
var jinaResp dto.RerankResponse
|
||||
if info.ChannelType == common.ChannelTypeXinference {
|
||||
if info.ChannelType == constant.ChannelTypeXinference {
|
||||
var xinRerankResponse xinference.XinRerankResponse
|
||||
err = common.DecodeJson(responseBody, &xinRerankResponse)
|
||||
err = common.UnmarshalJson(responseBody, &xinRerankResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
@@ -61,7 +59,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
},
|
||||
}
|
||||
} else {
|
||||
err = common.DecodeJson(responseBody, &jinaResp)
|
||||
err = common.UnmarshalJson(responseBody, &jinaResp)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
package constant
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
const (
|
||||
APITypeOpenAI = iota
|
||||
APITypeAnthropic
|
||||
APITypePaLM
|
||||
APITypeBaidu
|
||||
APITypeZhipu
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
APITypeZhipuV4
|
||||
APITypeOllama
|
||||
APITypePerplexity
|
||||
APITypeAws
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
APITypeSiliconFlow
|
||||
APITypeVertexAi
|
||||
APITypeMistral
|
||||
APITypeDeepSeek
|
||||
APITypeMokaAI
|
||||
APITypeVolcEngine
|
||||
APITypeBaiduV2
|
||||
APITypeOpenRouter
|
||||
APITypeXinference
|
||||
APITypeXai
|
||||
APITypeCoze
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
|
||||
func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType := -1
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
apiType = APITypeOpenAI
|
||||
case common.ChannelTypeAnthropic:
|
||||
apiType = APITypeAnthropic
|
||||
case common.ChannelTypeBaidu:
|
||||
apiType = APITypeBaidu
|
||||
case common.ChannelTypePaLM:
|
||||
apiType = APITypePaLM
|
||||
case common.ChannelTypeZhipu:
|
||||
apiType = APITypeZhipu
|
||||
case common.ChannelTypeAli:
|
||||
apiType = APITypeAli
|
||||
case common.ChannelTypeXunfei:
|
||||
apiType = APITypeXunfei
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
case common.ChannelTypeZhipu_v4:
|
||||
apiType = APITypeZhipuV4
|
||||
case common.ChannelTypeOllama:
|
||||
apiType = APITypeOllama
|
||||
case common.ChannelTypePerplexity:
|
||||
apiType = APITypePerplexity
|
||||
case common.ChannelTypeAws:
|
||||
apiType = APITypeAws
|
||||
case common.ChannelTypeCohere:
|
||||
apiType = APITypeCohere
|
||||
case common.ChannelTypeDify:
|
||||
apiType = APITypeDify
|
||||
case common.ChannelTypeJina:
|
||||
apiType = APITypeJina
|
||||
case common.ChannelCloudflare:
|
||||
apiType = APITypeCloudflare
|
||||
case common.ChannelTypeSiliconFlow:
|
||||
apiType = APITypeSiliconFlow
|
||||
case common.ChannelTypeVertexAi:
|
||||
apiType = APITypeVertexAi
|
||||
case common.ChannelTypeMistral:
|
||||
apiType = APITypeMistral
|
||||
case common.ChannelTypeDeepSeek:
|
||||
apiType = APITypeDeepSeek
|
||||
case common.ChannelTypeMokaAI:
|
||||
apiType = APITypeMokaAI
|
||||
case common.ChannelTypeVolcEngine:
|
||||
apiType = APITypeVolcEngine
|
||||
case common.ChannelTypeBaiduV2:
|
||||
apiType = APITypeBaiduV2
|
||||
case common.ChannelTypeOpenRouter:
|
||||
apiType = APITypeOpenRouter
|
||||
case common.ChannelTypeXinference:
|
||||
apiType = APITypeXinference
|
||||
case common.ChannelTypeXai:
|
||||
apiType = APITypeXai
|
||||
case common.ChannelTypeCoze:
|
||||
apiType = APITypeCoze
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
}
|
||||
return apiType, true
|
||||
}
|
||||
@@ -29,6 +29,8 @@ const (
|
||||
RelayModeMidjourneyShorten
|
||||
RelayModeSwapFace
|
||||
RelayModeMidjourneyUpload
|
||||
RelayModeMidjourneyVideo
|
||||
RelayModeMidjourneyEdits
|
||||
|
||||
RelayModeAudioSpeech // tts
|
||||
RelayModeAudioTranscription // whisper
|
||||
@@ -108,6 +110,10 @@ func Path2RelayModeMidjourney(path string) int {
|
||||
relayMode = RelayModeMidjourneyUpload
|
||||
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
||||
relayMode = RelayModeMidjourneyImagine
|
||||
} else if strings.HasSuffix(path, "/mj/submit/video") {
|
||||
relayMode = RelayModeMidjourneyVideo
|
||||
} else if strings.HasSuffix(path, "/mj/submit/edits") {
|
||||
relayMode = RelayModeMidjourneyEdits
|
||||
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
||||
relayMode = RelayModeMidjourneyBlend
|
||||
} else if strings.HasSuffix(path, "/mj/submit/describe") {
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
||||
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
|
||||
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
||||
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
|
||||
DefaultPingInterval = 10 * time.Second
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
scanner = bufio.NewScanner(resp.Body)
|
||||
ticker = time.NewTicker(streamingTimeout)
|
||||
pingTicker *time.Ticker
|
||||
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
||||
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
||||
wg sync.WaitGroup // 用于等待所有 goroutine 退出
|
||||
)
|
||||
|
||||
@@ -64,32 +64,39 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
pingTicker = time.NewTicker(pingInterval)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
// print timeout and ping interval for debugging
|
||||
println("relay timeout seconds:", common.RelayTimeout)
|
||||
println("streaming timeout seconds:", int64(streamingTimeout.Seconds()))
|
||||
println("ping interval seconds:", int64(pingInterval.Seconds()))
|
||||
}
|
||||
|
||||
// 改进资源清理,确保所有 goroutine 正确退出
|
||||
defer func() {
|
||||
// 通知所有 goroutine 停止
|
||||
common.SafeSendBool(stopChan, true)
|
||||
|
||||
|
||||
ticker.Stop()
|
||||
if pingTicker != nil {
|
||||
pingTicker.Stop()
|
||||
}
|
||||
|
||||
|
||||
// 等待所有 goroutine 退出,最多等待5秒
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
common.LogError(c, "timeout waiting for goroutines to exit")
|
||||
}
|
||||
|
||||
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
|
||||
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
SetEventStreamHeaders(c)
|
||||
@@ -113,12 +120,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
println("ping goroutine exited")
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// 添加超时保护,防止 goroutine 无限运行
|
||||
maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
|
||||
pingTimeout := time.NewTimer(maxPingDuration)
|
||||
defer pingTimeout.Stop()
|
||||
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
@@ -129,7 +136,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
defer writeMutex.Unlock()
|
||||
done <- PingData(c)
|
||||
}()
|
||||
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
@@ -175,7 +182,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
println("scanner goroutine exited")
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
for scanner.Scan() {
|
||||
// 检查是否需要停止
|
||||
select {
|
||||
@@ -187,7 +194,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
|
||||
ticker.Reset(streamingTimeout)
|
||||
data := scanner.Text()
|
||||
if common.DebugEnabled {
|
||||
@@ -205,7 +212,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
info.SetFirstResponseTime()
|
||||
|
||||
|
||||
// 使用超时机制防止写操作阻塞
|
||||
done := make(chan bool, 1)
|
||||
go func() {
|
||||
@@ -213,7 +220,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
defer writeMutex.Unlock()
|
||||
done <- dataHandler(data)
|
||||
}()
|
||||
|
||||
|
||||
select {
|
||||
case success := <-done:
|
||||
if !success {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -17,8 +18,6 @@ import (
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"one-api/relay/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
|
||||
@@ -106,6 +106,9 @@ func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
||||
midjourneyTask.StartTime = midjRequest.StartTime
|
||||
midjourneyTask.FinishTime = midjRequest.FinishTime
|
||||
midjourneyTask.ImageUrl = midjRequest.ImageUrl
|
||||
midjourneyTask.VideoUrl = midjRequest.VideoUrl
|
||||
videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls)
|
||||
midjourneyTask.VideoUrls = string(videoUrlsStr)
|
||||
midjourneyTask.Status = midjRequest.Status
|
||||
midjourneyTask.FailReason = midjRequest.FailReason
|
||||
err = midjourneyTask.Update()
|
||||
@@ -136,6 +139,9 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
} else {
|
||||
midjourneyTask.ImageUrl = originTask.ImageUrl
|
||||
}
|
||||
if originTask.VideoUrl != "" {
|
||||
midjourneyTask.VideoUrl = originTask.VideoUrl
|
||||
}
|
||||
midjourneyTask.Status = originTask.Status
|
||||
midjourneyTask.FailReason = originTask.FailReason
|
||||
midjourneyTask.Action = originTask.Action
|
||||
@@ -148,6 +154,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.Buttons = buttons
|
||||
}
|
||||
}
|
||||
if originTask.VideoUrls != "" {
|
||||
var videoUrls []dto.ImgUrls
|
||||
err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls)
|
||||
if err == nil {
|
||||
midjourneyTask.VideoUrls = videoUrls
|
||||
}
|
||||
}
|
||||
if originTask.Properties != "" {
|
||||
var properties dto.Properties
|
||||
err := json.Unmarshal([]byte(originTask.Properties), &properties)
|
||||
@@ -279,10 +292,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||||
}
|
||||
common.IOCopyBytesGracefully(c, nil, respBody)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -373,6 +383,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
relayMode = relayconstant.RelayModeMidjourneyChange
|
||||
}
|
||||
if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||
midjRequest.Action = constant.MjActionVideo
|
||||
}
|
||||
|
||||
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if midjRequest.Prompt == "" {
|
||||
@@ -381,6 +394,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
midjRequest.Action = constant.MjActionImagine
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||||
midjRequest.Action = constant.MjActionDescribe
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
||||
midjRequest.Action = constant.MjActionEdits
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||||
midjRequest.Action = constant.MjActionShorten
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
@@ -415,6 +430,14 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
//}
|
||||
mjId = midjRequest.TaskId
|
||||
midjRequest.Action = constant.MjActionModal
|
||||
} else if relayMode == relayconstant.RelayModeMidjourneyVideo {
|
||||
midjRequest.Action = constant.MjActionVideo
|
||||
if midjRequest.TaskId == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||||
} else if midjRequest.Action == "" {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
|
||||
}
|
||||
mjId = midjRequest.TaskId
|
||||
}
|
||||
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"one-api/constant"
|
||||
commonconstant "one-api/constant"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
@@ -32,7 +33,6 @@ import (
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
func GetAdaptor(apiType int) channel.Adaptor {
|
||||
|
||||
@@ -78,12 +78,15 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
|
||||
}
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
|
||||
@@ -103,6 +103,8 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
||||
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/edits", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/video", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/notify", controller.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/operation_setting"
|
||||
@@ -48,7 +49,7 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
|
||||
}
|
||||
if err.StatusCode == http.StatusForbidden {
|
||||
switch channelType {
|
||||
case common.ChannelTypeGemini:
|
||||
case constant.ChannelTypeGemini:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -19,7 +20,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
|
||||
Stream: claudeRequest.Stream,
|
||||
}
|
||||
|
||||
isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
|
||||
isOpenRouter := info.ChannelType == constant.ChannelTypeOpenRouter
|
||||
|
||||
if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
|
||||
if isOpenRouter {
|
||||
|
||||
@@ -90,10 +90,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (errWithStatu
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
var errResponse dto.GeneralErrorResponse
|
||||
err = json.Unmarshal(responseBody, &errResponse)
|
||||
if err != nil {
|
||||
|
||||
@@ -13,9 +13,8 @@ import (
|
||||
)
|
||||
|
||||
var httpClient *http.Client
|
||||
var impatientHTTPClient *http.Client
|
||||
|
||||
func init() {
|
||||
func InitHttpClient() {
|
||||
if common.RelayTimeout == 0 {
|
||||
httpClient = &http.Client{}
|
||||
} else {
|
||||
@@ -23,20 +22,12 @@ func init() {
|
||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
impatientHTTPClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func GetHttpClient() *http.Client {
|
||||
return httpClient
|
||||
}
|
||||
|
||||
func GetImpatientHttpClient() *http.Client {
|
||||
return impatientHTTPClient
|
||||
}
|
||||
|
||||
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
|
||||
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
if proxyURL == "" {
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CoverActionToModelName(mjAction string) string {
|
||||
@@ -38,6 +39,10 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeMidjourneyImagine:
|
||||
action = constant.MjActionImagine
|
||||
case relayconstant.RelayModeMidjourneyVideo:
|
||||
action = constant.MjActionVideo
|
||||
case relayconstant.RelayModeMidjourneyEdits:
|
||||
action = constant.MjActionEdits
|
||||
case relayconstant.RelayModeMidjourneyDescribe:
|
||||
action = constant.MjActionDescribe
|
||||
case relayconstant.RelayModeMidjourneyBlend:
|
||||
@@ -228,10 +233,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
respStr := string(responseBody)
|
||||
log.Printf("respStr: %s", respStr)
|
||||
if respStr == "" {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"log"
|
||||
"math"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -232,7 +232,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
cacheCreationRatio := priceData.CacheCreationRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
if relayInfo.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
|
||||
@@ -447,7 +447,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
||||
gopool.Go(func() {
|
||||
userSetting := relayInfo.UserSetting
|
||||
threshold := common.QuotaRemindThreshold
|
||||
if userCustomThreshold, ok := userSetting[constant2.UserSettingQuotaWarningThreshold]; ok {
|
||||
if userCustomThreshold, ok := userSetting[constant.UserSettingQuotaWarningThreshold]; ok {
|
||||
threshold = int(userCustomThreshold.(float64))
|
||||
}
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
|
||||
if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
var config image.Config
|
||||
@@ -172,9 +172,6 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
||||
}
|
||||
}
|
||||
toolTokens := CountTokenInput(countStr, request.Model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += 8
|
||||
tkm += toolTokens
|
||||
}
|
||||
@@ -195,9 +192,6 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += systemTokens
|
||||
}
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
client := GetImpatientHttpClient()
|
||||
client := GetHttpClient()
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send webhook request: %v", err)
|
||||
|
||||
@@ -231,7 +231,9 @@ var defaultModelPrice = map[string]float64{
|
||||
"dall-e-3": 0.04,
|
||||
"imagen-3.0-generate-002": 0.03,
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_video": 0.8,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_edits": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
@@ -509,8 +511,11 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
}
|
||||
return 3.5 / 0.15, false
|
||||
}
|
||||
if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") {
|
||||
return 4, true
|
||||
if strings.HasPrefix(name, "gemini-2.5-flash-lite") {
|
||||
if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") {
|
||||
return 4, false
|
||||
}
|
||||
return 4, false
|
||||
}
|
||||
return 2.5 / 0.3, true
|
||||
}
|
||||
|
||||
42
types/set.go
Normal file
42
types/set.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package types
|
||||
|
||||
type Set[T comparable] struct {
|
||||
items map[T]struct{}
|
||||
}
|
||||
|
||||
// NewSet 创建并返回一个新的 Set
|
||||
func NewSet[T comparable]() *Set[T] {
|
||||
return &Set[T]{
|
||||
items: make(map[T]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Set[T]) Add(item T) {
|
||||
s.items[item] = struct{}{}
|
||||
}
|
||||
|
||||
// Remove 从 Set 中移除一个元素
|
||||
func (s *Set[T]) Remove(item T) {
|
||||
delete(s.items, item)
|
||||
}
|
||||
|
||||
// Contains 检查 Set 是否包含某个元素
|
||||
func (s *Set[T]) Contains(item T) bool {
|
||||
_, exists := s.items[item]
|
||||
return exists
|
||||
}
|
||||
|
||||
// Len 返回 Set 中元素的数量
|
||||
func (s *Set[T]) Len() int {
|
||||
return len(s.items)
|
||||
}
|
||||
|
||||
// Items 返回 Set 中所有元素组成的切片
|
||||
// 注意:由于 map 的无序性,返回的切片元素顺序是随机的
|
||||
func (s *Set[T]) Items() []T {
|
||||
items := make([]T, 0, s.Len())
|
||||
for item := range s.items {
|
||||
items = append(items, item)
|
||||
}
|
||||
return items
|
||||
}
|
||||
@@ -399,7 +399,6 @@ const LoginForm = () => {
|
||||
placeholder={t('请输入您的用户名或邮箱地址')}
|
||||
name="username"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
onChange={(value) => handleChange('username', value)}
|
||||
prefix={<IconMail />}
|
||||
/>
|
||||
@@ -411,7 +410,6 @@ const LoginForm = () => {
|
||||
name="password"
|
||||
mode="password"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
onChange={(value) => handleChange('password', value)}
|
||||
prefix={<IconLock />}
|
||||
/>
|
||||
|
||||
@@ -113,7 +113,6 @@ const PasswordResetConfirm = () => {
|
||||
label={t('邮箱')}
|
||||
name="email"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
disabled={true}
|
||||
prefix={<IconMail />}
|
||||
placeholder={email ? '' : t('等待获取邮箱信息...')}
|
||||
@@ -125,7 +124,6 @@ const PasswordResetConfirm = () => {
|
||||
label={t('新密码')}
|
||||
name="newPassword"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
disabled={true}
|
||||
prefix={<IconLock />}
|
||||
suffix={
|
||||
|
||||
@@ -102,7 +102,6 @@ const PasswordResetForm = () => {
|
||||
placeholder={t('请输入您的邮箱地址')}
|
||||
name="email"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
value={email}
|
||||
onChange={handleChange}
|
||||
prefix={<IconMail />}
|
||||
|
||||
@@ -394,7 +394,6 @@ const RegisterForm = () => {
|
||||
placeholder={t('请输入用户名')}
|
||||
name="username"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
onChange={(value) => handleChange('username', value)}
|
||||
prefix={<IconUser />}
|
||||
/>
|
||||
@@ -406,7 +405,6 @@ const RegisterForm = () => {
|
||||
name="password"
|
||||
mode="password"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
onChange={(value) => handleChange('password', value)}
|
||||
prefix={<IconLock />}
|
||||
/>
|
||||
@@ -418,7 +416,6 @@ const RegisterForm = () => {
|
||||
name="password2"
|
||||
mode="password"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
onChange={(value) => handleChange('password2', value)}
|
||||
prefix={<IconLock />}
|
||||
/>
|
||||
@@ -432,7 +429,6 @@ const RegisterForm = () => {
|
||||
name="email"
|
||||
type="email"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
onChange={(value) => handleChange('email', value)}
|
||||
prefix={<IconMail />}
|
||||
suffix={
|
||||
@@ -440,7 +436,6 @@ const RegisterForm = () => {
|
||||
onClick={sendVerificationCode}
|
||||
loading={verificationCodeLoading}
|
||||
size="small"
|
||||
className="!rounded-md mr-2"
|
||||
>
|
||||
{t('获取验证码')}
|
||||
</Button>
|
||||
@@ -452,7 +447,6 @@ const RegisterForm = () => {
|
||||
placeholder={t('输入验证码')}
|
||||
name="verification_code"
|
||||
size="large"
|
||||
className="!rounded-md"
|
||||
onChange={(value) => handleChange('verification_code', value)}
|
||||
prefix={<IconKey />}
|
||||
/>
|
||||
|
||||
@@ -170,8 +170,8 @@ const NoticeModal = ({ visible, onClose, isMobile, defaultTab = 'inApp', unreadK
|
||||
onCancel={onClose}
|
||||
footer={(
|
||||
<div className="flex justify-end">
|
||||
<Button type='secondary' className='!rounded-full' onClick={handleCloseTodayNotice}>{t('今日关闭')}</Button>
|
||||
<Button type="primary" className='!rounded-full' onClick={onClose}>{t('关闭公告')}</Button>
|
||||
<Button type='secondary' onClick={handleCloseTodayNotice}>{t('今日关闭')}</Button>
|
||||
<Button type="primary" onClick={onClose}>{t('关闭公告')}</Button>
|
||||
</div>
|
||||
)}
|
||||
size={isMobile ? 'full-width' : 'large'}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user