mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-07 21:47:26 +00:00
Compare commits
114 Commits
coderabbit
...
v0.10.6-al
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6169f46cc6 | ||
|
|
d73a0440b9 | ||
|
|
688280b3c3 | ||
|
|
41da848c56 | ||
|
|
22b9438588 | ||
|
|
4ed4a7659a | ||
|
|
138fcd2327 | ||
|
|
62b796fa6a | ||
|
|
1a5c8f3c35 | ||
|
|
e1d43a820f | ||
|
|
c975b4cfa8 | ||
|
|
b51d1e2f78 | ||
|
|
07e77b3c6f | ||
|
|
c2464fc877 | ||
|
|
93012631d1 | ||
|
|
1ab0c35540 | ||
|
|
022fbcee04 | ||
|
|
aed1900364 | ||
|
|
f8938a8f78 | ||
|
|
e13459f3a1 | ||
|
|
ad61c0f89e | ||
|
|
9addf1b705 | ||
|
|
9e61338a6f | ||
|
|
d3f33932c0 | ||
|
|
a8f7c0614f | ||
|
|
5f37a1e97c | ||
|
|
177553af37 | ||
|
|
5ed4583c0c | ||
|
|
1519e97bc6 | ||
|
|
443b05821f | ||
|
|
6b1f8b84c6 | ||
|
|
22d0b73d21 | ||
|
|
e8aaed440c | ||
|
|
cfc72d6817 | ||
|
|
67ba913b44 | ||
|
|
1f78c6a0f9 | ||
|
|
9cf756fc4b | ||
|
|
c33ac97c71 | ||
|
|
c682e41338 | ||
|
|
817da8d73c | ||
|
|
43c671b8b3 | ||
|
|
3510b3e6fc | ||
|
|
f1ba624df2 | ||
|
|
2b8cbbe5ae | ||
|
|
fdd1ac0f6e | ||
|
|
e5679109d4 | ||
|
|
be8e644546 | ||
|
|
a0328b2f5e | ||
|
|
3402d09ed0 | ||
|
|
8abfbe372f | ||
|
|
a195e88896 | ||
|
|
d06915c30d | ||
|
|
b1bb64ae11 | ||
|
|
b2d8ad7883 | ||
|
|
ddb40b1a6e | ||
|
|
8b790446ce | ||
|
|
b808b96cce | ||
|
|
23a68137ad | ||
|
|
2a5b2add9a | ||
|
|
11922ef651 | ||
|
|
d474ed4778 | ||
|
|
ab81d6e444 | ||
|
|
ae2ca945f3 | ||
|
|
04ea79c429 | ||
|
|
48d358faec | ||
|
|
8063897998 | ||
|
|
923dfbeecb | ||
|
|
24d359cf40 | ||
|
|
725d61c5d3 | ||
|
|
1a69a93d20 | ||
|
|
1de78f8749 | ||
|
|
37a1882798 | ||
|
|
edbd5346e4 | ||
|
|
2c2dfea60f | ||
|
|
9aeef6abec | ||
|
|
58db72d459 | ||
|
|
654bb10b45 | ||
|
|
f51b5bb0c8 | ||
|
|
a4cd84f276 | ||
|
|
c722ddd58b | ||
|
|
88e394a976 | ||
|
|
31a3487139 | ||
|
|
a07406d97e | ||
|
|
f68858121c | ||
|
|
83fbaba768 | ||
|
|
d3c854fbed | ||
|
|
97b02685b1 | ||
|
|
da1b51ac31 | ||
|
|
f17b3810d6 | ||
|
|
8206084a77 | ||
|
|
559da6362a | ||
|
|
0b1a562df9 | ||
|
|
a0c3d37d66 | ||
|
|
347f2326f3 | ||
|
|
14c58aea77 | ||
|
|
09f3957362 | ||
|
|
31a79620ba | ||
|
|
12555a37d3 | ||
|
|
3652dfdbd5 | ||
|
|
dbaba87c39 | ||
|
|
f04ed7584a | ||
|
|
0a2f12c04e | ||
|
|
531dfb2555 | ||
|
|
5ef7247eac | ||
|
|
1168ddf9f9 | ||
|
|
da24a165d0 | ||
|
|
f88fc26150 | ||
|
|
2a511c6ee4 | ||
|
|
0217ed2f98 | ||
|
|
fcafadc6bb | ||
|
|
8e629a2a11 | ||
|
|
2a16c37aab | ||
|
|
681b37d104 | ||
|
|
35538ecb3b |
@@ -6,4 +6,5 @@
|
||||
Makefile
|
||||
docs
|
||||
.eslintcache
|
||||
.gocache
|
||||
.gocache
|
||||
/web/node_modules
|
||||
@@ -9,6 +9,14 @@
|
||||
# ENABLE_PPROF=true
|
||||
# 启用调试模式
|
||||
# DEBUG=true
|
||||
# Pyroscope 配置
|
||||
# PYROSCOPE_URL=http://localhost:4040
|
||||
# PYROSCOPE_APP_NAME=new-api
|
||||
# PYROSCOPE_BASIC_AUTH_USER=your-user
|
||||
# PYROSCOPE_BASIC_AUTH_PASSWORD=your-password
|
||||
# PYROSCOPE_MUTEX_RATE=5
|
||||
# PYROSCOPE_BLOCK_RATE=5
|
||||
# HOSTNAME=your-hostname
|
||||
|
||||
# 数据库相关配置
|
||||
# 数据库连接字符串
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -19,7 +19,11 @@ tiktoken_cache
|
||||
.gomodcache/
|
||||
.cache
|
||||
web/bun.lock
|
||||
plans
|
||||
|
||||
electron/node_modules
|
||||
electron/dist
|
||||
data/
|
||||
.gomodcache/
|
||||
.gocache-temp
|
||||
.gopath
|
||||
|
||||
15
README.en.md
15
README.en.md
@@ -213,9 +213,11 @@ docker run --name new-api -d --restart always \
|
||||
- 🚦 User-level model rate limiting
|
||||
|
||||
**Format Conversion:**
|
||||
- 🔄 OpenAI ⇄ Claude Messages
|
||||
- 🔄 OpenAI ⇄ Gemini Chat
|
||||
- 🔄 Thinking-to-content functionality
|
||||
- 🔄 **OpenAI Compatible ⇄ Claude Messages**
|
||||
- 🔄 **OpenAI Compatible → Google Gemini**
|
||||
- 🔄 **Google Gemini → OpenAI Compatible** - Text only, function calling not supported yet
|
||||
- 🚧 **OpenAI Compatible ⇄ OpenAI Responses** - In development
|
||||
- 🔄 **Thinking-to-content functionality**
|
||||
|
||||
**Reasoning Effort Support:**
|
||||
|
||||
@@ -308,6 +310,13 @@ docker run --name new-api -d --restart always \
|
||||
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Azure API version | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | Error log switch | `false` |
|
||||
| `PYROSCOPE_URL` | Pyroscope server address | - |
|
||||
| `PYROSCOPE_APP_NAME` | Pyroscope application name | `new-api` |
|
||||
| `PYROSCOPE_BASIC_AUTH_USER` | Pyroscope basic auth user | - |
|
||||
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Pyroscope basic auth password | - |
|
||||
| `PYROSCOPE_MUTEX_RATE` | Pyroscope mutex sampling rate | `5` |
|
||||
| `PYROSCOPE_BLOCK_RATE` | Pyroscope block sampling rate | `5` |
|
||||
| `HOSTNAME` | Hostname tag for Pyroscope | `new-api` |
|
||||
|
||||
📖 **Complete configuration:** [Environment Variables Documentation](https://docs.newapi.pro/en/docs/installation/config-maintenance/environment-variables)
|
||||
|
||||
|
||||
15
README.fr.md
15
README.fr.md
@@ -212,9 +212,11 @@ docker run --name new-api -d --restart always \
|
||||
- 🚦 Limitation du débit du modèle pour les utilisateurs
|
||||
|
||||
**Conversion de format:**
|
||||
- 🔄 OpenAI ⇄ Claude Messages
|
||||
- 🔄 OpenAI ⇄ Gemini Chat
|
||||
- 🔄 Fonctionnalité de la pensée au contenu
|
||||
- 🔄 **OpenAI Compatible ⇄ Claude Messages**
|
||||
- 🔄 **OpenAI Compatible → Google Gemini**
|
||||
- 🔄 **Google Gemini → OpenAI Compatible** - Texte uniquement, les appels de fonction ne sont pas encore pris en charge
|
||||
- 🚧 **OpenAI Compatible ⇄ OpenAI Responses** - En développement
|
||||
- 🔄 **Fonctionnalité de la pensée au contenu**
|
||||
|
||||
**Prise en charge de l'effort de raisonnement:**
|
||||
|
||||
@@ -304,6 +306,13 @@ docker run --name new-api -d --restart always \
|
||||
| `MAX_REQUEST_BODY_MB` | Taille maximale du corps de requête (Mo, comptée **après décompression** ; évite les requêtes énormes/zip bombs qui saturent la mémoire). Dépassement ⇒ `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Version de l'API Azure | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | Interrupteur du journal d'erreurs | `false` |
|
||||
| `PYROSCOPE_URL` | Adresse du serveur Pyroscope | - |
|
||||
| `PYROSCOPE_APP_NAME` | Nom de l'application Pyroscope | `new-api` |
|
||||
| `PYROSCOPE_BASIC_AUTH_USER` | Utilisateur Basic Auth Pyroscope | - |
|
||||
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Mot de passe Basic Auth Pyroscope | - |
|
||||
| `PYROSCOPE_MUTEX_RATE` | Taux d'échantillonnage mutex Pyroscope | `5` |
|
||||
| `PYROSCOPE_BLOCK_RATE` | Taux d'échantillonnage block Pyroscope | `5` |
|
||||
| `HOSTNAME` | Nom d'hôte tagué pour Pyroscope | `new-api` |
|
||||
|
||||
📖 **Configuration complète:** [Documentation des variables d'environnement](https://docs.newapi.pro/en/docs/installation/config-maintenance/environment-variables)
|
||||
|
||||
|
||||
15
README.ja.md
15
README.ja.md
@@ -218,9 +218,11 @@ docker run --name new-api -d --restart always \
|
||||
- 🚦 ユーザーレベルモデルレート制限
|
||||
|
||||
**フォーマット変換:**
|
||||
- 🔄 OpenAI ⇄ Claude Messages
|
||||
- 🔄 OpenAI ⇄ Gemini Chat
|
||||
- 🔄 思考からコンテンツへの機能
|
||||
- 🔄 **OpenAI Compatible ⇄ Claude Messages**
|
||||
- 🔄 **OpenAI Compatible → Google Gemini**
|
||||
- 🔄 **Google Gemini → OpenAI Compatible** - テキストのみ、関数呼び出しはまだサポートされていません
|
||||
- 🚧 **OpenAI Compatible ⇄ OpenAI Responses** - 開発中
|
||||
- 🔄 **思考からコンテンツへの機能**
|
||||
|
||||
**Reasoning Effort サポート:**
|
||||
|
||||
@@ -313,6 +315,13 @@ docker run --name new-api -d --restart always \
|
||||
| `MAX_REQUEST_BODY_MB` | リクエストボディ最大サイズ(MB、**解凍後**に計測。巨大リクエスト/zip bomb によるメモリ枯渇を防止)。超過時は `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Azure APIバージョン | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | エラーログスイッチ | `false` |
|
||||
| `PYROSCOPE_URL` | Pyroscopeサーバーのアドレス | - |
|
||||
| `PYROSCOPE_APP_NAME` | Pyroscopeアプリ名 | `new-api` |
|
||||
| `PYROSCOPE_BASIC_AUTH_USER` | Pyroscope Basic Authユーザー | - |
|
||||
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Pyroscope Basic Authパスワード | - |
|
||||
| `PYROSCOPE_MUTEX_RATE` | Pyroscope mutexサンプリング率 | `5` |
|
||||
| `PYROSCOPE_BLOCK_RATE` | Pyroscope blockサンプリング率 | `5` |
|
||||
| `HOSTNAME` | Pyroscope用のホスト名タグ | `new-api` |
|
||||
|
||||
📖 **完全な設定:** [環境変数ドキュメント](https://docs.newapi.pro/ja/docs/installation/config-maintenance/environment-variables)
|
||||
|
||||
|
||||
15
README.md
15
README.md
@@ -214,9 +214,11 @@ docker run --name new-api -d --restart always \
|
||||
- 🚦 用户级别模型限流
|
||||
|
||||
**格式转换:**
|
||||
- 🔄 OpenAI ⇄ Claude Messages
|
||||
- 🔄 OpenAI ⇄ Gemini Chat
|
||||
- 🔄 思考转内容功能
|
||||
- 🔄 **OpenAI Compatible ⇄ Claude Messages**
|
||||
- 🔄 **OpenAI Compatible → Google Gemini**
|
||||
- 🔄 **Google Gemini → OpenAI Compatible** - 仅支持文本,暂不支持函数调用
|
||||
- 🚧 **OpenAI Compatible ⇄ OpenAI Responses** - 开发中
|
||||
- 🔄 **思考转内容功能**
|
||||
|
||||
**Reasoning Effort 支持:**
|
||||
|
||||
@@ -309,6 +311,13 @@ docker run --name new-api -d --restart always \
|
||||
| `MAX_REQUEST_BODY_MB` | 请求体最大大小(MB,**解压后**计;防止超大请求/zip bomb 导致内存暴涨),超过将返回 `413` | `32` |
|
||||
| `AZURE_DEFAULT_API_VERSION` | Azure API 版本 | `2025-04-01-preview` |
|
||||
| `ERROR_LOG_ENABLED` | 错误日志开关 | `false` |
|
||||
| `PYROSCOPE_URL` | Pyroscope 服务地址 | - |
|
||||
| `PYROSCOPE_APP_NAME` | Pyroscope 应用名 | `new-api` |
|
||||
| `PYROSCOPE_BASIC_AUTH_USER` | Pyroscope Basic Auth 用户名 | - |
|
||||
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Pyroscope Basic Auth 密码 | - |
|
||||
| `PYROSCOPE_MUTEX_RATE` | Pyroscope mutex 采样率 | `5` |
|
||||
| `PYROSCOPE_BLOCK_RATE` | Pyroscope block 采样率 | `5` |
|
||||
| `HOSTNAME` | Pyroscope 标签里的主机名 | `new-api` |
|
||||
|
||||
📖 **完整配置:** [环境变量文档](https://docs.newapi.pro/zh/docs/installation/config-maintenance/environment-variables)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
maxMB := constant.MaxRequestBodyMB
|
||||
if maxMB < 0 {
|
||||
if maxMB <= 0 {
|
||||
// no limit
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
_ = c.Request.Body.Close()
|
||||
|
||||
@@ -115,10 +115,10 @@ func InitEnv() {
|
||||
func initConstantEnv() {
|
||||
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
|
||||
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 64)
|
||||
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
|
||||
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 64)
|
||||
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 128)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
|
||||
56
common/pyro.go
Normal file
56
common/pyro.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"github.com/grafana/pyroscope-go"
|
||||
)
|
||||
|
||||
func StartPyroScope() error {
|
||||
|
||||
pyroscopeUrl := GetEnvOrDefaultString("PYROSCOPE_URL", "")
|
||||
if pyroscopeUrl == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
pyroscopeAppName := GetEnvOrDefaultString("PYROSCOPE_APP_NAME", "new-api")
|
||||
pyroscopeBasicAuthUser := GetEnvOrDefaultString("PYROSCOPE_BASIC_AUTH_USER", "")
|
||||
pyroscopeBasicAuthPassword := GetEnvOrDefaultString("PYROSCOPE_BASIC_AUTH_PASSWORD", "")
|
||||
pyroscopeHostname := GetEnvOrDefaultString("HOSTNAME", "new-api")
|
||||
|
||||
mutexRate := GetEnvOrDefault("PYROSCOPE_MUTEX_RATE", 5)
|
||||
blockRate := GetEnvOrDefault("PYROSCOPE_BLOCK_RATE", 5)
|
||||
|
||||
runtime.SetMutexProfileFraction(mutexRate)
|
||||
runtime.SetBlockProfileRate(blockRate)
|
||||
|
||||
_, err := pyroscope.Start(pyroscope.Config{
|
||||
ApplicationName: pyroscopeAppName,
|
||||
|
||||
ServerAddress: pyroscopeUrl,
|
||||
BasicAuthUser: pyroscopeBasicAuthUser,
|
||||
BasicAuthPassword: pyroscopeBasicAuthPassword,
|
||||
|
||||
Logger: nil,
|
||||
|
||||
Tags: map[string]string{"hostname": pyroscopeHostname},
|
||||
|
||||
ProfileTypes: []pyroscope.ProfileType{
|
||||
pyroscope.ProfileCPU,
|
||||
pyroscope.ProfileAllocObjects,
|
||||
pyroscope.ProfileAllocSpace,
|
||||
pyroscope.ProfileInuseObjects,
|
||||
pyroscope.ProfileInuseSpace,
|
||||
|
||||
pyroscope.ProfileGoroutines,
|
||||
pyroscope.ProfileMutexCount,
|
||||
pyroscope.ProfileMutexDuration,
|
||||
pyroscope.ProfileBlockCount,
|
||||
pyroscope.ProfileBlockDuration,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -16,6 +16,8 @@ var (
|
||||
maskURLPattern = regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
|
||||
maskDomainPattern = regexp.MustCompile(`\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}\b`)
|
||||
maskIPPattern = regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
|
||||
// maskApiKeyPattern matches patterns like 'api_key:xxx' or "api_key:xxx" to mask the API key value
|
||||
maskApiKeyPattern = regexp.MustCompile(`(['"]?)api_key:([^\s'"]+)(['"]?)`)
|
||||
)
|
||||
|
||||
func GetStringIfEmpty(str string, defaultValue string) string {
|
||||
@@ -235,5 +237,8 @@ func MaskSensitiveInfo(str string) string {
|
||||
// Mask IP addresses
|
||||
str = maskIPPattern.ReplaceAllString(str, "***.***.***.***")
|
||||
|
||||
// Mask API keys (e.g., "api_key:AIzaSyAAAaUooTUni8AdaOkSRMda30n_Q4vrV70" -> "api_key:***")
|
||||
str = maskApiKeyPattern.ReplaceAllString(str, "${1}api_key:***${3}")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
@@ -40,13 +40,6 @@ type testResult struct {
|
||||
newAPIError *types.NewAPIError
|
||||
}
|
||||
|
||||
// testChannel executes a test request against the given channel using the provided testModel and optional endpointType,
|
||||
// and returns a testResult containing the test context and any encountered error information.
|
||||
// It selects or derives a model when testModel is empty, auto-detects the request endpoint (chat, responses, embeddings, images, rerank) when endpointType is not specified,
|
||||
// converts and relays the request to the upstream adapter, and parses the upstream response to collect usage and pricing information.
|
||||
// On upstream responses that indicate the chat/completions `messages` parameter is unsupported and endpointType was not specified, it will retry the test using the Responses API.
|
||||
// The function records consumption logs and returns a testResult with a populated context on success, or with localErr/newAPIError set on failure;
|
||||
// for channel types that are not supported for testing it returns a localErr explaining that the channel test is not supported.
|
||||
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
|
||||
tik := time.Now()
|
||||
var unsupportedTestChannelTypes = []int{
|
||||
@@ -82,8 +75,6 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
}
|
||||
|
||||
originTestModel := testModel
|
||||
|
||||
requestPath := "/v1/chat/completions"
|
||||
|
||||
// 如果指定了端点类型,使用指定的端点类型
|
||||
@@ -93,10 +84,6 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
} else {
|
||||
// 如果没有指定端点类型,使用原有的自动检测逻辑
|
||||
if common.IsOpenAIResponseOnlyModel(testModel) {
|
||||
requestPath = "/v1/responses"
|
||||
}
|
||||
|
||||
// 先判断是否为 Embedding 模型
|
||||
if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
||||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
|
||||
@@ -110,6 +97,11 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
||||
requestPath = "/v1/images/generations"
|
||||
}
|
||||
|
||||
// responses-only models
|
||||
if strings.Contains(strings.ToLower(testModel), "codex") {
|
||||
requestPath = "/v1/responses"
|
||||
}
|
||||
}
|
||||
|
||||
c.Request = &http.Request{
|
||||
@@ -189,7 +181,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
}
|
||||
|
||||
request := buildTestRequest(testModel, endpointType)
|
||||
request := buildTestRequest(testModel, endpointType, channel)
|
||||
|
||||
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
||||
|
||||
@@ -201,6 +193,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
}
|
||||
|
||||
info.IsChannelTest = true
|
||||
info.InitChannelMeta(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info, request)
|
||||
@@ -317,6 +310,27 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
||||
}
|
||||
}
|
||||
|
||||
//jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
//if err != nil {
|
||||
// return testResult{
|
||||
// context: c,
|
||||
// localErr: err,
|
||||
// newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
||||
// }
|
||||
//}
|
||||
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
newAPIError: types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
@@ -332,13 +346,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
||||
// 自动检测模式下,如果上游不支持 chat.completions 的 messages 参数,尝试切换到 Responses API 再测一次。
|
||||
if endpointType == "" && requestPath == "/v1/chat/completions" && err != nil {
|
||||
lowerErr := strings.ToLower(err.Error())
|
||||
if strings.Contains(lowerErr, "unsupported parameter") && strings.Contains(lowerErr, "messages") {
|
||||
return testChannel(channel, originTestModel, string(constant.EndpointTypeOpenAIResponse))
|
||||
}
|
||||
}
|
||||
common.SysError(fmt.Sprintf(
|
||||
"channel test bad response: channel_id=%d name=%s type=%d model=%s endpoint_type=%s status=%d err=%v",
|
||||
channel.Id,
|
||||
channel.Name,
|
||||
channel.Type,
|
||||
testModel,
|
||||
endpointType,
|
||||
httpResp.StatusCode,
|
||||
err,
|
||||
))
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
@@ -409,8 +426,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
|
||||
}
|
||||
}
|
||||
|
||||
// for embedding models, and otherwise a chat/completion request with model-specific token limit heuristics.
|
||||
func buildTestRequest(model string, endpointType string) dto.Request {
|
||||
func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
|
||||
// 根据端点类型构建不同的测试请求
|
||||
if endpointType != "" {
|
||||
switch constant.EndpointType(endpointType) {
|
||||
@@ -438,16 +454,13 @@ func buildTestRequest(model string, endpointType string) dto.Request {
|
||||
}
|
||||
case constant.EndpointTypeOpenAIResponse:
|
||||
// 返回 OpenAIResponsesRequest
|
||||
maxOutputTokens := uint(10)
|
||||
return &dto.OpenAIResponsesRequest{
|
||||
Model: model,
|
||||
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
||||
MaxOutputTokens: maxOutputTokens,
|
||||
Stream: true,
|
||||
Model: model,
|
||||
Input: json.RawMessage("\"hi\""),
|
||||
}
|
||||
case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
|
||||
// 返回 GeneralOpenAIRequest
|
||||
maxTokens := uint(10)
|
||||
maxTokens := uint(16)
|
||||
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
|
||||
maxTokens = 3000
|
||||
}
|
||||
@@ -466,16 +479,6 @@ func buildTestRequest(model string, endpointType string) dto.Request {
|
||||
}
|
||||
|
||||
// 自动检测逻辑(保持原有行为)
|
||||
if common.IsOpenAIResponseOnlyModel(model) {
|
||||
maxOutputTokens := uint(10)
|
||||
return &dto.OpenAIResponsesRequest{
|
||||
Model: model,
|
||||
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
||||
MaxOutputTokens: maxOutputTokens,
|
||||
Stream: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 先判断是否为 Embedding 模型
|
||||
if strings.Contains(strings.ToLower(model), "embedding") ||
|
||||
strings.HasPrefix(model, "m3e") ||
|
||||
@@ -487,6 +490,14 @@ func buildTestRequest(model string, endpointType string) dto.Request {
|
||||
}
|
||||
}
|
||||
|
||||
// Responses-only models (e.g. codex series)
|
||||
if strings.Contains(strings.ToLower(model), "codex") {
|
||||
return &dto.OpenAIResponsesRequest{
|
||||
Model: model,
|
||||
Input: json.RawMessage("\"hi\""),
|
||||
}
|
||||
}
|
||||
|
||||
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
|
||||
testRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: model,
|
||||
@@ -500,7 +511,7 @@ func buildTestRequest(model string, endpointType string) dto.Request {
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "o") {
|
||||
testRequest.MaxCompletionTokens = 10
|
||||
testRequest.MaxCompletionTokens = 16
|
||||
} else if strings.Contains(model, "thinking") {
|
||||
if !strings.Contains(model, "claude") {
|
||||
testRequest.MaxTokens = 50
|
||||
@@ -508,7 +519,7 @@ func buildTestRequest(model string, endpointType string) dto.Request {
|
||||
} else if strings.Contains(model, "gemini") {
|
||||
testRequest.MaxTokens = 3000
|
||||
} else {
|
||||
testRequest.MaxTokens = 10
|
||||
testRequest.MaxTokens = 16
|
||||
}
|
||||
|
||||
return testRequest
|
||||
@@ -674,4 +685,4 @@ func AutomaticallyTestChannels() {
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,16 +11,19 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel/gemini"
|
||||
"github.com/QuantumNous/new-api/relay/channel/ollama"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Permission []struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
@@ -207,11 +210,88 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
// 对于 Ollama 渠道,使用特殊处理
|
||||
if channel.Type == constant.ChannelTypeOllama {
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
models, err := ollama.FetchOllamaModels(baseURL, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
result := OpenAIModelsResponse{
|
||||
Data: make([]OpenAIModel, 0, len(models)),
|
||||
}
|
||||
|
||||
for _, modelInfo := range models {
|
||||
metadata := map[string]any{}
|
||||
if modelInfo.Size > 0 {
|
||||
metadata["size"] = modelInfo.Size
|
||||
}
|
||||
if modelInfo.Digest != "" {
|
||||
metadata["digest"] = modelInfo.Digest
|
||||
}
|
||||
if modelInfo.ModifiedAt != "" {
|
||||
metadata["modified_at"] = modelInfo.ModifiedAt
|
||||
}
|
||||
details := modelInfo.Details
|
||||
if details.ParentModel != "" || details.Format != "" || details.Family != "" || len(details.Families) > 0 || details.ParameterSize != "" || details.QuantizationLevel != "" {
|
||||
metadata["details"] = modelInfo.Details
|
||||
}
|
||||
if len(metadata) == 0 {
|
||||
metadata = nil
|
||||
}
|
||||
|
||||
result.Data = append(result.Data, OpenAIModel{
|
||||
ID: modelInfo.Name,
|
||||
Object: "model",
|
||||
Created: 0,
|
||||
OwnedBy: "ollama",
|
||||
Metadata: metadata,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": result.Data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 对于 Gemini 渠道,使用特殊处理
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
// 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥)
|
||||
key, _, apiErr := channel.GetNextEnabledKey()
|
||||
if apiErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
models, err := gemini.FetchGeminiModels(baseURL, key, channel.GetSetting().Proxy)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": models,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var url string
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeGemini:
|
||||
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
|
||||
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
|
||||
case constant.ChannelTypeAli:
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
@@ -917,9 +997,6 @@ func UpdateChannel(c *gin.Context) {
|
||||
// 单个JSON密钥
|
||||
newKeys = []string{channel.Key}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
} else {
|
||||
// 普通渠道的处理
|
||||
inputKeys := strings.Split(channel.Key, "\n")
|
||||
@@ -929,10 +1006,31 @@ func UpdateChannel(c *gin.Context) {
|
||||
newKeys = append(newKeys, key)
|
||||
}
|
||||
}
|
||||
// 合并密钥
|
||||
allKeys := append(existingKeys, newKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(existingKeys)+len(newKeys))
|
||||
for _, key := range existingKeys {
|
||||
normalized := strings.TrimSpace(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
}
|
||||
dedupedNewKeys := make([]string, 0, len(newKeys))
|
||||
for _, key := range newKeys {
|
||||
normalized := strings.TrimSpace(key)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
seen[normalized] = struct{}{}
|
||||
dedupedNewKeys = append(dedupedNewKeys, normalized)
|
||||
}
|
||||
|
||||
allKeys := append(existingKeys, dedupedNewKeys...)
|
||||
channel.Key = strings.Join(allKeys, "\n")
|
||||
}
|
||||
case "replace":
|
||||
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
|
||||
@@ -975,6 +1073,49 @@ func FetchModels(c *gin.Context) {
|
||||
baseURL = constant.ChannelBaseURLs[req.Type]
|
||||
}
|
||||
|
||||
// remove line breaks and extra spaces.
|
||||
key := strings.TrimSpace(req.Key)
|
||||
key = strings.Split(key, "\n")[0]
|
||||
|
||||
if req.Type == constant.ChannelTypeOllama {
|
||||
models, err := ollama.FetchOllamaModels(baseURL, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(models))
|
||||
for _, modelInfo := range models {
|
||||
names = append(names, modelInfo.Name)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": names,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Type == constant.ChannelTypeGemini {
|
||||
models, err := gemini.FetchGeminiModels(baseURL, key, "")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": models,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
@@ -987,10 +1128,6 @@ func FetchModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// remove line breaks and extra spaces.
|
||||
key := strings.TrimSpace(req.Key)
|
||||
// If the key contains a line break, only take the first part.
|
||||
key = strings.Split(key, "\n")[0]
|
||||
request.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
response, err := client.Do(request)
|
||||
@@ -1640,3 +1777,262 @@ func ManageMultiKeys(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// OllamaPullModel 拉取 Ollama 模型
|
||||
func OllamaPullModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ChannelID int `json:"channel_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request parameters",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ChannelID == 0 || req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Channel ID and model name are required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取渠道信息
|
||||
channel, err := model.GetChannelById(req.ChannelID, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Channel not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否是 Ollama 渠道
|
||||
if channel.Type != constant.ChannelTypeOllama {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "This operation is only supported for Ollama channels",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
err = ollama.PullOllamaModel(baseURL, key, req.ModelName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("Failed to pull model: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
|
||||
})
|
||||
}
|
||||
|
||||
// OllamaPullModelStream 流式拉取 Ollama 模型
|
||||
func OllamaPullModelStream(c *gin.Context) {
|
||||
var req struct {
|
||||
ChannelID int `json:"channel_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request parameters",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ChannelID == 0 || req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Channel ID and model name are required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取渠道信息
|
||||
channel, err := model.GetChannelById(req.ChannelID, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Channel not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否是 Ollama 渠道
|
||||
if channel.Type != constant.ChannelTypeOllama {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "This operation is only supported for Ollama channels",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
// 设置 SSE 头部
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
|
||||
// 创建进度回调函数
|
||||
progressCallback := func(progress ollama.OllamaPullResponse) {
|
||||
data, _ := json.Marshal(progress)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", string(data))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// 执行拉取
|
||||
err = ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback)
|
||||
|
||||
if err != nil {
|
||||
errorData, _ := json.Marshal(gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorData))
|
||||
} else {
|
||||
successData, _ := json.Marshal(gin.H{
|
||||
"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
|
||||
})
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", string(successData))
|
||||
}
|
||||
|
||||
// 发送结束标志
|
||||
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
// OllamaDeleteModel 删除 Ollama 模型
|
||||
func OllamaDeleteModel(c *gin.Context) {
|
||||
var req struct {
|
||||
ChannelID int `json:"channel_id"`
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request parameters",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if req.ChannelID == 0 || req.ModelName == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Channel ID and model name are required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取渠道信息
|
||||
channel, err := model.GetChannelById(req.ChannelID, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Channel not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否是 Ollama 渠道
|
||||
if channel.Type != constant.ChannelTypeOllama {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "This operation is only supported for Ollama channels",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
err = ollama.DeleteOllamaModel(baseURL, key, req.ModelName)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("Failed to delete model: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("Model %s deleted successfully", req.ModelName),
|
||||
})
|
||||
}
|
||||
|
||||
// OllamaVersion 获取 Ollama 服务版本信息
|
||||
func OllamaVersion(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid channel id",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Channel not found",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if channel.Type != constant.ChannelTypeOllama {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "This operation is only supported for Ollama channels",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := constant.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
version, err := ollama.FetchOllamaVersion(baseURL, key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("获取Ollama版本失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"version": version,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
72
controller/checkin.go
Normal file
72
controller/checkin.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetCheckinStatus 获取用户签到状态和历史记录
|
||||
func GetCheckinStatus(c *gin.Context) {
|
||||
setting := operation_setting.GetCheckinSetting()
|
||||
if !setting.Enabled {
|
||||
common.ApiErrorMsg(c, "签到功能未启用")
|
||||
return
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
// 获取月份参数,默认为当前月份
|
||||
month := c.DefaultQuery("month", time.Now().Format("2006-01"))
|
||||
|
||||
stats, err := model.GetUserCheckinStats(userId, month)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"enabled": setting.Enabled,
|
||||
"min_quota": setting.MinQuota,
|
||||
"max_quota": setting.MaxQuota,
|
||||
"stats": stats,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// DoCheckin 执行用户签到
|
||||
func DoCheckin(c *gin.Context) {
|
||||
setting := operation_setting.GetCheckinSetting()
|
||||
if !setting.Enabled {
|
||||
common.ApiErrorMsg(c, "签到功能未启用")
|
||||
return
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
|
||||
checkin, err := model.UserCheckin(userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("用户签到,获得额度 %s", logger.LogQuota(checkin.QuotaAwarded)))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "签到成功",
|
||||
"data": gin.H{
|
||||
"quota_awarded": checkin.QuotaAwarded,
|
||||
"checkin_date": checkin.CheckinDate},
|
||||
})
|
||||
}
|
||||
810
controller/deployment.go
Normal file
810
controller/deployment.go
Normal file
@@ -0,0 +1,810 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/pkg/ionet"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getIoAPIKey(c *gin.Context) (string, bool) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
enabled := common.OptionMap["model_deployment.ionet.enabled"] == "true"
|
||||
apiKey := common.OptionMap["model_deployment.ionet.api_key"]
|
||||
common.OptionMapRWMutex.RUnlock()
|
||||
if !enabled || strings.TrimSpace(apiKey) == "" {
|
||||
common.ApiErrorMsg(c, "io.net model deployment is not enabled or api key missing")
|
||||
return "", false
|
||||
}
|
||||
return apiKey, true
|
||||
}
|
||||
|
||||
func GetModelDeploymentSettings(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
enabled := common.OptionMap["model_deployment.ionet.enabled"] == "true"
|
||||
hasAPIKey := strings.TrimSpace(common.OptionMap["model_deployment.ionet.api_key"]) != ""
|
||||
common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"provider": "io.net",
|
||||
"enabled": enabled,
|
||||
"configured": hasAPIKey,
|
||||
"can_connect": enabled && hasAPIKey,
|
||||
})
|
||||
}
|
||||
|
||||
func getIoClient(c *gin.Context) (*ionet.Client, bool) {
|
||||
apiKey, ok := getIoAPIKey(c)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return ionet.NewClient(apiKey), true
|
||||
}
|
||||
|
||||
func getIoEnterpriseClient(c *gin.Context) (*ionet.Client, bool) {
|
||||
apiKey, ok := getIoAPIKey(c)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return ionet.NewEnterpriseClient(apiKey), true
|
||||
}
|
||||
|
||||
func TestIoNetConnection(c *gin.Context) {
|
||||
var req struct {
|
||||
APIKey string `json:"api_key"`
|
||||
}
|
||||
|
||||
rawBody, err := c.GetRawData()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if len(bytes.TrimSpace(rawBody)) > 0 {
|
||||
if err := json.Unmarshal(rawBody, &req); err != nil {
|
||||
common.ApiErrorMsg(c, "invalid request payload")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
apiKey := strings.TrimSpace(req.APIKey)
|
||||
if apiKey == "" {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
storedKey := strings.TrimSpace(common.OptionMap["model_deployment.ionet.api_key"])
|
||||
common.OptionMapRWMutex.RUnlock()
|
||||
if storedKey == "" {
|
||||
common.ApiErrorMsg(c, "api_key is required")
|
||||
return
|
||||
}
|
||||
apiKey = storedKey
|
||||
}
|
||||
|
||||
client := ionet.NewEnterpriseClient(apiKey)
|
||||
result, err := client.GetMaxGPUsPerContainer()
|
||||
if err != nil {
|
||||
if apiErr, ok := err.(*ionet.APIError); ok {
|
||||
message := strings.TrimSpace(apiErr.Message)
|
||||
if message == "" {
|
||||
message = "failed to validate api key"
|
||||
}
|
||||
common.ApiErrorMsg(c, message)
|
||||
return
|
||||
}
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
totalHardware := 0
|
||||
totalAvailable := 0
|
||||
if result != nil {
|
||||
totalHardware = len(result.Hardware)
|
||||
totalAvailable = result.Total
|
||||
if totalAvailable == 0 {
|
||||
for _, hw := range result.Hardware {
|
||||
totalAvailable += hw.Available
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, gin.H{
|
||||
"hardware_count": totalHardware,
|
||||
"total_available": totalAvailable,
|
||||
})
|
||||
}
|
||||
|
||||
func requireDeploymentID(c *gin.Context) (string, bool) {
|
||||
deploymentID := strings.TrimSpace(c.Param("id"))
|
||||
if deploymentID == "" {
|
||||
common.ApiErrorMsg(c, "deployment ID is required")
|
||||
return "", false
|
||||
}
|
||||
return deploymentID, true
|
||||
}
|
||||
|
||||
func requireContainerID(c *gin.Context) (string, bool) {
|
||||
containerID := strings.TrimSpace(c.Param("container_id"))
|
||||
if containerID == "" {
|
||||
common.ApiErrorMsg(c, "container ID is required")
|
||||
return "", false
|
||||
}
|
||||
return containerID, true
|
||||
}
|
||||
|
||||
func mapIoNetDeployment(d ionet.Deployment) map[string]interface{} {
|
||||
var created int64
|
||||
if d.CreatedAt.IsZero() {
|
||||
created = time.Now().Unix()
|
||||
} else {
|
||||
created = d.CreatedAt.Unix()
|
||||
}
|
||||
|
||||
timeRemainingHours := d.ComputeMinutesRemaining / 60
|
||||
timeRemainingMins := d.ComputeMinutesRemaining % 60
|
||||
var timeRemaining string
|
||||
if timeRemainingHours > 0 {
|
||||
timeRemaining = fmt.Sprintf("%d hour %d minutes", timeRemainingHours, timeRemainingMins)
|
||||
} else if timeRemainingMins > 0 {
|
||||
timeRemaining = fmt.Sprintf("%d minutes", timeRemainingMins)
|
||||
} else {
|
||||
timeRemaining = "completed"
|
||||
}
|
||||
|
||||
hardwareInfo := fmt.Sprintf("%s %s x%d", d.BrandName, d.HardwareName, d.HardwareQuantity)
|
||||
|
||||
return map[string]interface{}{
|
||||
"id": d.ID,
|
||||
"deployment_name": d.Name,
|
||||
"container_name": d.Name,
|
||||
"status": strings.ToLower(d.Status),
|
||||
"type": "Container",
|
||||
"time_remaining": timeRemaining,
|
||||
"time_remaining_minutes": d.ComputeMinutesRemaining,
|
||||
"hardware_info": hardwareInfo,
|
||||
"hardware_name": d.HardwareName,
|
||||
"brand_name": d.BrandName,
|
||||
"hardware_quantity": d.HardwareQuantity,
|
||||
"completed_percent": d.CompletedPercent,
|
||||
"compute_minutes_served": d.ComputeMinutesServed,
|
||||
"compute_minutes_remaining": d.ComputeMinutesRemaining,
|
||||
"created_at": created,
|
||||
"updated_at": created,
|
||||
"model_name": "",
|
||||
"model_version": "",
|
||||
"instance_count": d.HardwareQuantity,
|
||||
"resource_config": map[string]interface{}{
|
||||
"cpu": "",
|
||||
"memory": "",
|
||||
"gpu": strconv.Itoa(d.HardwareQuantity),
|
||||
},
|
||||
"description": "",
|
||||
"provider": "io.net",
|
||||
}
|
||||
}
|
||||
|
||||
func computeStatusCounts(total int, deployments []ionet.Deployment) map[string]int64 {
|
||||
counts := map[string]int64{
|
||||
"all": int64(total),
|
||||
}
|
||||
|
||||
for _, status := range []string{"running", "completed", "failed", "deployment requested", "termination requested", "destroyed"} {
|
||||
counts[status] = 0
|
||||
}
|
||||
|
||||
for _, d := range deployments {
|
||||
status := strings.ToLower(strings.TrimSpace(d.Status))
|
||||
counts[status] = counts[status] + 1
|
||||
}
|
||||
|
||||
return counts
|
||||
}
|
||||
|
||||
func GetAllDeployments(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
status := c.Query("status")
|
||||
opts := &ionet.ListDeploymentsOptions{
|
||||
Status: strings.ToLower(strings.TrimSpace(status)),
|
||||
Page: pageInfo.GetPage(),
|
||||
PageSize: pageInfo.GetPageSize(),
|
||||
SortBy: "created_at",
|
||||
SortOrder: "desc",
|
||||
}
|
||||
|
||||
dl, err := client.ListDeployments(opts)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
items := make([]map[string]interface{}, 0, len(dl.Deployments))
|
||||
for _, d := range dl.Deployments {
|
||||
items = append(items, mapIoNetDeployment(d))
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"page": pageInfo.GetPage(),
|
||||
"page_size": pageInfo.GetPageSize(),
|
||||
"total": dl.Total,
|
||||
"items": items,
|
||||
"status_counts": computeStatusCounts(dl.Total, dl.Deployments),
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func SearchDeployments(c *gin.Context) {
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
status := strings.ToLower(strings.TrimSpace(c.Query("status")))
|
||||
keyword := strings.TrimSpace(c.Query("keyword"))
|
||||
|
||||
dl, err := client.ListDeployments(&ionet.ListDeploymentsOptions{
|
||||
Status: status,
|
||||
Page: pageInfo.GetPage(),
|
||||
PageSize: pageInfo.GetPageSize(),
|
||||
SortBy: "created_at",
|
||||
SortOrder: "desc",
|
||||
})
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
filtered := make([]ionet.Deployment, 0, len(dl.Deployments))
|
||||
if keyword == "" {
|
||||
filtered = dl.Deployments
|
||||
} else {
|
||||
kw := strings.ToLower(keyword)
|
||||
for _, d := range dl.Deployments {
|
||||
if strings.Contains(strings.ToLower(d.Name), kw) {
|
||||
filtered = append(filtered, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
items := make([]map[string]interface{}, 0, len(filtered))
|
||||
for _, d := range filtered {
|
||||
items = append(items, mapIoNetDeployment(d))
|
||||
}
|
||||
|
||||
total := dl.Total
|
||||
if keyword != "" {
|
||||
total = len(filtered)
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"page": pageInfo.GetPage(),
|
||||
"page_size": pageInfo.GetPageSize(),
|
||||
"total": total,
|
||||
"items": items,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func GetDeployment(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
details, err := client.GetDeployment(deploymentID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"id": details.ID,
|
||||
"deployment_name": details.ID,
|
||||
"model_name": "",
|
||||
"model_version": "",
|
||||
"status": strings.ToLower(details.Status),
|
||||
"instance_count": details.TotalContainers,
|
||||
"hardware_id": details.HardwareID,
|
||||
"resource_config": map[string]interface{}{
|
||||
"cpu": "",
|
||||
"memory": "",
|
||||
"gpu": strconv.Itoa(details.TotalGPUs),
|
||||
},
|
||||
"created_at": details.CreatedAt.Unix(),
|
||||
"updated_at": details.CreatedAt.Unix(),
|
||||
"description": "",
|
||||
"amount_paid": details.AmountPaid,
|
||||
"completed_percent": details.CompletedPercent,
|
||||
"gpus_per_container": details.GPUsPerContainer,
|
||||
"total_gpus": details.TotalGPUs,
|
||||
"total_containers": details.TotalContainers,
|
||||
"hardware_name": details.HardwareName,
|
||||
"brand_name": details.BrandName,
|
||||
"compute_minutes_served": details.ComputeMinutesServed,
|
||||
"compute_minutes_remaining": details.ComputeMinutesRemaining,
|
||||
"locations": details.Locations,
|
||||
"container_config": details.ContainerConfig,
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func UpdateDeploymentName(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
updateReq := &ionet.UpdateClusterNameRequest{
|
||||
Name: strings.TrimSpace(req.Name),
|
||||
}
|
||||
|
||||
if updateReq.Name == "" {
|
||||
common.ApiErrorMsg(c, "deployment name cannot be empty")
|
||||
return
|
||||
}
|
||||
|
||||
available, err := client.CheckClusterNameAvailability(updateReq.Name)
|
||||
if err != nil {
|
||||
common.ApiError(c, fmt.Errorf("failed to check name availability: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if !available {
|
||||
common.ApiErrorMsg(c, "deployment name is not available, please choose a different name")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.UpdateClusterName(deploymentID, updateReq)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"status": resp.Status,
|
||||
"message": resp.Message,
|
||||
"id": deploymentID,
|
||||
"name": updateReq.Name,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func UpdateDeployment(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req ionet.UpdateDeploymentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.UpdateDeployment(deploymentID, &req)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"status": resp.Status,
|
||||
"deployment_id": resp.DeploymentID,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func ExtendDeployment(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req ionet.ExtendDurationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
details, err := client.ExtendDeployment(deploymentID, &req)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := mapIoNetDeployment(ionet.Deployment{
|
||||
ID: details.ID,
|
||||
Status: details.Status,
|
||||
Name: deploymentID,
|
||||
CompletedPercent: float64(details.CompletedPercent),
|
||||
HardwareQuantity: details.TotalGPUs,
|
||||
BrandName: details.BrandName,
|
||||
HardwareName: details.HardwareName,
|
||||
ComputeMinutesServed: details.ComputeMinutesServed,
|
||||
ComputeMinutesRemaining: details.ComputeMinutesRemaining,
|
||||
CreatedAt: details.CreatedAt,
|
||||
})
|
||||
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func DeleteDeployment(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.DeleteDeployment(deploymentID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"status": resp.Status,
|
||||
"deployment_id": resp.DeploymentID,
|
||||
"message": "Deployment termination requested successfully",
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func CreateDeployment(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req ionet.DeploymentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.DeployContainer(&req)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"deployment_id": resp.DeploymentID,
|
||||
"status": resp.Status,
|
||||
"message": "Deployment created successfully",
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func GetHardwareTypes(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
hardwareTypes, totalAvailable, err := client.ListHardwareTypes()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"hardware_types": hardwareTypes,
|
||||
"total": len(hardwareTypes),
|
||||
"total_available": totalAvailable,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func GetLocations(c *gin.Context) {
|
||||
client, ok := getIoClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
locationsResp, err := client.ListLocations()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
total := locationsResp.Total
|
||||
if total == 0 {
|
||||
total = len(locationsResp.Locations)
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"locations": locationsResp.Locations,
|
||||
"total": total,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func GetAvailableReplicas(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
hardwareIDStr := c.Query("hardware_id")
|
||||
gpuCountStr := c.Query("gpu_count")
|
||||
|
||||
if hardwareIDStr == "" {
|
||||
common.ApiErrorMsg(c, "hardware_id parameter is required")
|
||||
return
|
||||
}
|
||||
|
||||
hardwareID, err := strconv.Atoi(hardwareIDStr)
|
||||
if err != nil || hardwareID <= 0 {
|
||||
common.ApiErrorMsg(c, "invalid hardware_id parameter")
|
||||
return
|
||||
}
|
||||
|
||||
gpuCount := 1
|
||||
if gpuCountStr != "" {
|
||||
if parsed, err := strconv.Atoi(gpuCountStr); err == nil && parsed > 0 {
|
||||
gpuCount = parsed
|
||||
}
|
||||
}
|
||||
|
||||
replicas, err := client.GetAvailableReplicas(hardwareID, gpuCount)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, replicas)
|
||||
}
|
||||
|
||||
func GetPriceEstimation(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req ionet.PriceEstimationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
priceResp, err := client.GetPriceEstimation(&req)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, priceResp)
|
||||
}
|
||||
|
||||
func CheckClusterNameAvailability(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
clusterName := strings.TrimSpace(c.Query("name"))
|
||||
if clusterName == "" {
|
||||
common.ApiErrorMsg(c, "name parameter is required")
|
||||
return
|
||||
}
|
||||
|
||||
available, err := client.CheckClusterNameAvailability(clusterName)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"available": available,
|
||||
"name": clusterName,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
func GetDeploymentLogs(c *gin.Context) {
|
||||
client, ok := getIoClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
containerID := c.Query("container_id")
|
||||
if containerID == "" {
|
||||
common.ApiErrorMsg(c, "container_id parameter is required")
|
||||
return
|
||||
}
|
||||
level := c.Query("level")
|
||||
stream := c.Query("stream")
|
||||
cursor := c.Query("cursor")
|
||||
limitStr := c.Query("limit")
|
||||
follow := c.Query("follow") == "true"
|
||||
|
||||
var limit int = 100
|
||||
if limitStr != "" {
|
||||
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
|
||||
limit = parsedLimit
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
opts := &ionet.GetLogsOptions{
|
||||
Level: level,
|
||||
Stream: stream,
|
||||
Limit: limit,
|
||||
Cursor: cursor,
|
||||
Follow: follow,
|
||||
}
|
||||
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse(time.RFC3339, startTime); err == nil {
|
||||
opts.StartTime = &t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse(time.RFC3339, endTime); err == nil {
|
||||
opts.EndTime = &t
|
||||
}
|
||||
}
|
||||
|
||||
rawLogs, err := client.GetContainerLogsRaw(deploymentID, containerID, opts)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, rawLogs)
|
||||
}
|
||||
|
||||
func ListDeploymentContainers(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
containers, err := client.ListContainers(deploymentID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
items := make([]map[string]interface{}, 0)
|
||||
if containers != nil {
|
||||
items = make([]map[string]interface{}, 0, len(containers.Workers))
|
||||
for _, ctr := range containers.Workers {
|
||||
events := make([]map[string]interface{}, 0, len(ctr.ContainerEvents))
|
||||
for _, event := range ctr.ContainerEvents {
|
||||
events = append(events, map[string]interface{}{
|
||||
"time": event.Time.Unix(),
|
||||
"message": event.Message,
|
||||
})
|
||||
}
|
||||
|
||||
items = append(items, map[string]interface{}{
|
||||
"container_id": ctr.ContainerID,
|
||||
"device_id": ctr.DeviceID,
|
||||
"status": strings.ToLower(strings.TrimSpace(ctr.Status)),
|
||||
"hardware": ctr.Hardware,
|
||||
"brand_name": ctr.BrandName,
|
||||
"created_at": ctr.CreatedAt.Unix(),
|
||||
"uptime_percent": ctr.UptimePercent,
|
||||
"gpus_per_container": ctr.GPUsPerContainer,
|
||||
"public_url": ctr.PublicURL,
|
||||
"events": events,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"total": 0,
|
||||
"containers": items,
|
||||
}
|
||||
if containers != nil {
|
||||
response["total"] = containers.Total
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, response)
|
||||
}
|
||||
|
||||
func GetContainerDetails(c *gin.Context) {
|
||||
client, ok := getIoEnterpriseClient(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
deploymentID, ok := requireDeploymentID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
containerID, ok := requireContainerID(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
details, err := client.GetContainerDetails(deploymentID, containerID)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if details == nil {
|
||||
common.ApiErrorMsg(c, "container details not found")
|
||||
return
|
||||
}
|
||||
|
||||
events := make([]map[string]interface{}, 0, len(details.ContainerEvents))
|
||||
for _, event := range details.ContainerEvents {
|
||||
events = append(events, map[string]interface{}{
|
||||
"time": event.Time.Unix(),
|
||||
"message": event.Message,
|
||||
})
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"deployment_id": deploymentID,
|
||||
"container_id": details.ContainerID,
|
||||
"device_id": details.DeviceID,
|
||||
"status": strings.ToLower(strings.TrimSpace(details.Status)),
|
||||
"hardware": details.Hardware,
|
||||
"brand_name": details.BrandName,
|
||||
"created_at": details.CreatedAt.Unix(),
|
||||
"uptime_percent": details.UptimePercent,
|
||||
"gpus_per_container": details.GPUsPerContainer,
|
||||
"public_url": details.PublicURL,
|
||||
"events": events,
|
||||
}
|
||||
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
@@ -114,6 +114,7 @@ func GetStatus(c *gin.Context) {
|
||||
"setup": constant.Setup,
|
||||
"user_agreement_enabled": legalSetting.UserAgreement != "",
|
||||
"privacy_policy_enabled": legalSetting.PrivacyPolicy != "",
|
||||
"checkin_enabled": operation_setting.GetCheckinSetting().Enabled,
|
||||
}
|
||||
|
||||
// 根据启用状态注入可选内容
|
||||
|
||||
@@ -249,7 +249,9 @@ func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, v
|
||||
return 0
|
||||
}
|
||||
|
||||
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
|
||||
// SyncUpstreamModels 同步上游模型与供应商:
|
||||
// - 默认仅创建「未配置模型」
|
||||
// - 可通过 overwrite 选择性覆盖更新本地已有模型的字段(前提:sync_official <> 0)
|
||||
func SyncUpstreamModels(c *gin.Context) {
|
||||
var req syncRequest
|
||||
// 允许空体
|
||||
@@ -260,12 +262,26 @@ func SyncUpstreamModels(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(missing) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
|
||||
"created_models": 0,
|
||||
"created_vendors": 0,
|
||||
"skipped_models": []string{},
|
||||
}})
|
||||
|
||||
// 若既无缺失模型需要创建,也未指定覆盖更新字段,则无需请求上游数据,直接返回
|
||||
if len(missing) == 0 && len(req.Overwrite) == 0 {
|
||||
modelsURL, vendorsURL := getUpstreamURLs(req.Locale)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"created_models": 0,
|
||||
"created_vendors": 0,
|
||||
"updated_models": 0,
|
||||
"skipped_models": []string{},
|
||||
"created_list": []string{},
|
||||
"updated_list": []string{},
|
||||
"source": gin.H{
|
||||
"locale": req.Locale,
|
||||
"models_url": modelsURL,
|
||||
"vendors_url": vendorsURL,
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -315,9 +331,9 @@ func SyncUpstreamModels(c *gin.Context) {
|
||||
createdModels := 0
|
||||
createdVendors := 0
|
||||
updatedModels := 0
|
||||
var skipped []string
|
||||
var createdList []string
|
||||
var updatedList []string
|
||||
skipped := make([]string, 0)
|
||||
createdList := make([]string, 0)
|
||||
updatedList := make([]string, 0)
|
||||
|
||||
// 本地缓存:vendorName -> id
|
||||
vendorIDCache := make(map[string]int)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting"
|
||||
"github.com/QuantumNous/new-api/setting/console_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
@@ -20,7 +21,11 @@ func GetOptions(c *gin.Context) {
|
||||
var options []*model.Option
|
||||
common.OptionMapRWMutex.Lock()
|
||||
for k, v := range common.OptionMap {
|
||||
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
|
||||
if strings.HasSuffix(k, "Token") ||
|
||||
strings.HasSuffix(k, "Secret") ||
|
||||
strings.HasSuffix(k, "Key") ||
|
||||
strings.HasSuffix(k, "secret") ||
|
||||
strings.HasSuffix(k, "api_key") {
|
||||
continue
|
||||
}
|
||||
options = append(options, &model.Option{
|
||||
@@ -173,6 +178,15 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "AutomaticDisableStatusCodes":
|
||||
_, err = operation_setting.ParseHTTPStatusCodeRanges(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.api_info":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
|
||||
if err != nil {
|
||||
|
||||
@@ -348,7 +348,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelType, err) && channelError.AutoBan {
|
||||
gopool.Go(func() {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
service.DisableChannel(channelError, err.ErrorWithStatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -378,7 +378,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
other["admin_info"] = adminInfo
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
|
||||
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, 0, false, userGroup, other)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -74,7 +74,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
key := channel.Key
|
||||
|
||||
privateData := task.PrivateData
|
||||
if privateData.Key != "" {
|
||||
key = privateData.Key
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
}, proxy)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -149,6 +150,24 @@ func AddToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// 非无限额度时,检查额度值是否超出有效范围
|
||||
if !token.UnlimitedQuota {
|
||||
if token.RemainQuota < 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "额度值不能为负数",
|
||||
})
|
||||
return
|
||||
}
|
||||
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
|
||||
if token.RemainQuota > maxQuotaValue {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -216,6 +235,23 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if !token.UnlimitedQuota {
|
||||
if token.RemainQuota < 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "额度值不能为负数",
|
||||
})
|
||||
return
|
||||
}
|
||||
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
|
||||
if token.RemainQuota > maxQuotaValue {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -261,7 +297,6 @@ func UpdateToken(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": cleanToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type TokenBatch struct {
|
||||
|
||||
@@ -110,18 +110,17 @@ func setupLogin(user *model.User, c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
cleanUser := model.User{
|
||||
Id: user.Id,
|
||||
Username: user.Username,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
Group: user.Group,
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "",
|
||||
"success": true,
|
||||
"data": cleanUser,
|
||||
"data": map[string]any{
|
||||
"id": user.Id,
|
||||
"username": user.Username,
|
||||
"display_name": user.DisplayName,
|
||||
"role": user.Role,
|
||||
"status": user.Status,
|
||||
"group": user.Group,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -764,7 +763,10 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) {
|
||||
|
||||
// 密码不为空,需要验证原密码
|
||||
// 支持第一次账号绑定时原密码为空的情况
|
||||
if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) && currentUser.Password != "" {
|
||||
err = fmt.Errorf("原密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
7
docs/ionet-client.md
Normal file
7
docs/ionet-client.md
Normal file
@@ -0,0 +1,7 @@
|
||||
Request URL
|
||||
https://api.io.solutions/v1/io-cloud/clusters/654fc0a9-0d4a-4db4-9b95-3f56189348a2/update-name
|
||||
Request Method
|
||||
PUT
|
||||
|
||||
{"status":"succeeded","message":"Cluster name updated successfully"}
|
||||
|
||||
@@ -26,6 +26,7 @@ type GeneralErrorResponse struct {
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
|
||||
105
dto/gemini.go
105
dto/gemini.go
@@ -22,6 +22,27 @@ type GeminiChatRequest struct {
|
||||
CachedContent string `json:"cachedContent,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiChatRequest to accept both snake_case and camelCase fields.
|
||||
func (r *GeminiChatRequest) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiChatRequest
|
||||
var aux struct {
|
||||
Alias
|
||||
SystemInstructionSnake *GeminiChatContent `json:"system_instruction,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*r = GeminiChatRequest(aux.Alias)
|
||||
|
||||
if aux.SystemInstructionSnake != nil {
|
||||
r.SystemInstructions = aux.SystemInstructionSnake
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ToolConfig struct {
|
||||
FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
|
||||
RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
|
||||
@@ -105,7 +126,7 @@ func (r *GeminiChatRequest) SetModelName(modelName string) {
|
||||
|
||||
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
|
||||
var tools []GeminiChatTool
|
||||
if strings.HasSuffix(string(r.Tools), "[") {
|
||||
if strings.HasPrefix(string(r.Tools), "[") {
|
||||
// is array
|
||||
if err := common.Unmarshal(r.Tools, &tools); err != nil {
|
||||
logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())
|
||||
@@ -320,6 +341,88 @@ type GeminiChatGenerationConfig struct {
|
||||
ImageConfig json.RawMessage `json:"imageConfig,omitempty"` // RawMessage to allow flexible image config
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiChatGenerationConfig to accept both snake_case and camelCase fields.
|
||||
func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiChatGenerationConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
TopPSnake float64 `json:"top_p,omitempty"`
|
||||
TopKSnake float64 `json:"top_k,omitempty"`
|
||||
MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"`
|
||||
CandidateCountSnake int `json:"candidate_count,omitempty"`
|
||||
StopSequencesSnake []string `json:"stop_sequences,omitempty"`
|
||||
ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"`
|
||||
ResponseSchemaSnake any `json:"response_schema,omitempty"`
|
||||
ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"`
|
||||
PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"`
|
||||
ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"`
|
||||
MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"`
|
||||
ResponseModalitiesSnake []string `json:"response_modalities,omitempty"`
|
||||
ThinkingConfigSnake *GeminiThinkingConfig `json:"thinking_config,omitempty"`
|
||||
SpeechConfigSnake json.RawMessage `json:"speech_config,omitempty"`
|
||||
ImageConfigSnake json.RawMessage `json:"image_config,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GeminiChatGenerationConfig(aux.Alias)
|
||||
|
||||
// Prioritize snake_case if present
|
||||
if aux.TopPSnake != 0 {
|
||||
c.TopP = aux.TopPSnake
|
||||
}
|
||||
if aux.TopKSnake != 0 {
|
||||
c.TopK = aux.TopKSnake
|
||||
}
|
||||
if aux.MaxOutputTokensSnake != 0 {
|
||||
c.MaxOutputTokens = aux.MaxOutputTokensSnake
|
||||
}
|
||||
if aux.CandidateCountSnake != 0 {
|
||||
c.CandidateCount = aux.CandidateCountSnake
|
||||
}
|
||||
if len(aux.StopSequencesSnake) > 0 {
|
||||
c.StopSequences = aux.StopSequencesSnake
|
||||
}
|
||||
if aux.ResponseMimeTypeSnake != "" {
|
||||
c.ResponseMimeType = aux.ResponseMimeTypeSnake
|
||||
}
|
||||
if aux.ResponseSchemaSnake != nil {
|
||||
c.ResponseSchema = aux.ResponseSchemaSnake
|
||||
}
|
||||
if len(aux.ResponseJsonSchemaSnake) > 0 {
|
||||
c.ResponseJsonSchema = aux.ResponseJsonSchemaSnake
|
||||
}
|
||||
if aux.PresencePenaltySnake != nil {
|
||||
c.PresencePenalty = aux.PresencePenaltySnake
|
||||
}
|
||||
if aux.FrequencyPenaltySnake != nil {
|
||||
c.FrequencyPenalty = aux.FrequencyPenaltySnake
|
||||
}
|
||||
if aux.ResponseLogprobsSnake {
|
||||
c.ResponseLogprobs = aux.ResponseLogprobsSnake
|
||||
}
|
||||
if aux.MediaResolutionSnake != "" {
|
||||
c.MediaResolution = aux.MediaResolutionSnake
|
||||
}
|
||||
if len(aux.ResponseModalitiesSnake) > 0 {
|
||||
c.ResponseModalities = aux.ResponseModalitiesSnake
|
||||
}
|
||||
if aux.ThinkingConfigSnake != nil {
|
||||
c.ThinkingConfig = aux.ThinkingConfigSnake
|
||||
}
|
||||
if len(aux.SpeechConfigSnake) > 0 {
|
||||
c.SpeechConfig = aux.SpeechConfigSnake
|
||||
}
|
||||
if len(aux.ImageConfigSnake) > 0 {
|
||||
c.ImageConfig = aux.ImageConfigSnake
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type MediaResolution string
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
|
||||
@@ -167,9 +167,9 @@ func (i *ImageRequest) SetModelName(modelName string) {
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Data []ImageData `json:"data"`
|
||||
Created int64 `json:"created"`
|
||||
Extra any `json:"extra,omitempty"`
|
||||
Data []ImageData `json:"data"`
|
||||
Created int64 `json:"created"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
}
|
||||
type ImageData struct {
|
||||
Url string `json:"url"`
|
||||
|
||||
@@ -23,6 +23,8 @@ type FormatJsonSchema struct {
|
||||
Strict json.RawMessage `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// GeneralOpenAIRequest represents a general request structure for OpenAI-compatible APIs.
|
||||
// 参数增加规范:无引用的参数必须使用json.RawMessage类型,并添加omitempty标签
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
@@ -82,8 +84,9 @@ type GeneralOpenAIRequest struct {
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"`
|
||||
EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
|
||||
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
|
||||
EnableSearch json.RawMessage `json:"enable_search,omitempty"`
|
||||
// ollama Params
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
// baidu v2
|
||||
@@ -805,11 +808,11 @@ type OpenAIResponsesRequest struct {
|
||||
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
|
||||
PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
|
||||
@@ -334,13 +334,16 @@ type IncompleteDetails struct {
|
||||
}
|
||||
|
||||
type ResponsesOutput struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Content []ResponsesOutputContent `json:"content"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Content []ResponsesOutputContent `json:"content"`
|
||||
Quality string `json:"quality"`
|
||||
Size string `json:"size"`
|
||||
CallId string `json:"call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type ResponsesOutputContent struct {
|
||||
|
||||
8
go.mod
8
go.mod
@@ -27,6 +27,7 @@ require (
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/grafana/pyroscope-go v1.2.7
|
||||
github.com/jfreymuth/oggvorbis v1.0.5
|
||||
github.com/jinzhu/copier v0.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
@@ -36,6 +37,7 @@ require (
|
||||
github.com/samber/lo v1.52.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/stripe/stripe-go/v81 v81.4.0
|
||||
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300
|
||||
github.com/thanhpk/randstr v1.0.6
|
||||
@@ -62,6 +64,7 @@ require (
|
||||
github.com/bytedance/sonic/loader v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
@@ -77,11 +80,11 @@ require (
|
||||
github.com/go-sql-driver/mysql v1.7.0 // indirect
|
||||
github.com/go-webauthn/x v0.1.25 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/go-tpm v0.9.5 // indirect
|
||||
github.com/gorilla/context v1.1.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
||||
github.com/icza/bitio v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
@@ -91,6 +94,7 @@ require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.17.8 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
@@ -101,7 +105,9 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
|
||||
48
go.sum
48
go.sum
@@ -118,9 +118,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
|
||||
github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU=
|
||||
github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -132,6 +131,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grafana/pyroscope-go v1.2.7 h1:VWBBlqxjyR0Cwk2W6UrE8CdcdD80GOFNutj0Kb1T8ac=
|
||||
github.com/grafana/pyroscope-go v1.2.7/go.mod h1:o/bpSLiJYYP6HQtvcoVKiE9s5RiNgjYTj1DhiddP2Pc=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 h1:c1Us8i6eSmkW+Ez05d3co8kasnuOY813tbMN8i/a3Og=
|
||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
|
||||
github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0=
|
||||
github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
|
||||
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=
|
||||
@@ -160,12 +163,15 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
|
||||
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
@@ -214,14 +220,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
|
||||
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
|
||||
github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw=
|
||||
github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
|
||||
@@ -231,6 +234,7 @@ github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+D
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
@@ -288,12 +292,12 @@ golang.org/x/arch v0.21.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
@@ -321,6 +325,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
@@ -350,19 +356,29 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
|
||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
|
||||
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
|
||||
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
|
||||
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
|
||||
modernc.org/cc/v4 v4.26.5/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.1 h1:wPKYn5EC/mYTqBO373jKjvX2n+3+aK7+sICCv4Fjy1A=
|
||||
modernc.org/ccgo/v4 v4.28.1/go.mod h1:uD+4RnfrVgE6ec9NGguUNdhqzNIeeomeXf6CL0GTE5Q=
|
||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A=
|
||||
modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I=
|
||||
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
||||
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
||||
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
||||
7
main.go
7
main.go
@@ -124,6 +124,11 @@ func main() {
|
||||
common.SysLog("pprof enabled")
|
||||
}
|
||||
|
||||
err = common.StartPyroScope()
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("start pyroscope error : %v", err))
|
||||
}
|
||||
|
||||
// Initialize HTTP server
|
||||
server := gin.New()
|
||||
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
||||
@@ -183,6 +188,7 @@ func InjectUmamiAnalytics() {
|
||||
analyticsInjectBuilder.WriteString(umamiSiteID)
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<!--Umami QuantumNous-->\n")
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
@@ -204,6 +210,7 @@ func InjectGoogleAnalytics() {
|
||||
analyticsInjectBuilder.WriteString("');")
|
||||
analyticsInjectBuilder.WriteString("</script>")
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<!--Google Analytics QuantumNous-->\n")
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
|
||||
@@ -195,8 +195,8 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||
}
|
||||
// 检查path包含/v1/messages
|
||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
|
||||
// 检查path包含/v1/messages 或 /v1/models
|
||||
if strings.Contains(c.Request.URL.Path, "/v1/messages") || strings.Contains(c.Request.URL.Path, "/v1/models") {
|
||||
anthropicKey := c.Request.Header.Get("x-api-key")
|
||||
if anthropicKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
|
||||
@@ -218,10 +218,14 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
parts := make([]string, 0)
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
|
||||
key = strings.TrimSpace(key[7:])
|
||||
}
|
||||
if key == "" || key == "midjourney-proxy" {
|
||||
key = c.Request.Header.Get("mj-api-secret")
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
|
||||
key = strings.TrimSpace(key[7:])
|
||||
}
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts = strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
|
||||
179
model/checkin.go
Normal file
179
model/checkin.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Checkin 签到记录
|
||||
type Checkin struct {
|
||||
Id int `json:"id" gorm:"primaryKey;autoIncrement"`
|
||||
UserId int `json:"user_id" gorm:"not null;uniqueIndex:idx_user_checkin_date"`
|
||||
CheckinDate string `json:"checkin_date" gorm:"type:varchar(10);not null;uniqueIndex:idx_user_checkin_date"` // 格式: YYYY-MM-DD
|
||||
QuotaAwarded int `json:"quota_awarded" gorm:"not null"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint"`
|
||||
}
|
||||
|
||||
// CheckinRecord 用于API返回的签到记录(不包含敏感字段)
|
||||
type CheckinRecord struct {
|
||||
CheckinDate string `json:"checkin_date"`
|
||||
QuotaAwarded int `json:"quota_awarded"`
|
||||
}
|
||||
|
||||
func (Checkin) TableName() string {
|
||||
return "checkins"
|
||||
}
|
||||
|
||||
// GetUserCheckinRecords 获取用户在指定日期范围内的签到记录
|
||||
func GetUserCheckinRecords(userId int, startDate, endDate string) ([]Checkin, error) {
|
||||
var records []Checkin
|
||||
err := DB.Where("user_id = ? AND checkin_date >= ? AND checkin_date <= ?",
|
||||
userId, startDate, endDate).
|
||||
Order("checkin_date DESC").
|
||||
Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// HasCheckedInToday 检查用户今天是否已签到
|
||||
func HasCheckedInToday(userId int) (bool, error) {
|
||||
today := time.Now().Format("2006-01-02")
|
||||
var count int64
|
||||
err := DB.Model(&Checkin{}).
|
||||
Where("user_id = ? AND checkin_date = ?", userId, today).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// UserCheckin 执行用户签到
|
||||
// MySQL 和 PostgreSQL 使用事务保证原子性
|
||||
// SQLite 不支持嵌套事务,使用顺序操作 + 手动回滚
|
||||
func UserCheckin(userId int) (*Checkin, error) {
|
||||
setting := operation_setting.GetCheckinSetting()
|
||||
if !setting.Enabled {
|
||||
return nil, errors.New("签到功能未启用")
|
||||
}
|
||||
|
||||
// 检查今天是否已签到
|
||||
hasChecked, err := HasCheckedInToday(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hasChecked {
|
||||
return nil, errors.New("今日已签到")
|
||||
}
|
||||
|
||||
// 计算随机额度奖励
|
||||
quotaAwarded := setting.MinQuota
|
||||
if setting.MaxQuota > setting.MinQuota {
|
||||
quotaAwarded = setting.MinQuota + rand.Intn(setting.MaxQuota-setting.MinQuota+1)
|
||||
}
|
||||
|
||||
today := time.Now().Format("2006-01-02")
|
||||
checkin := &Checkin{
|
||||
UserId: userId,
|
||||
CheckinDate: today,
|
||||
QuotaAwarded: quotaAwarded,
|
||||
CreatedAt: time.Now().Unix(),
|
||||
}
|
||||
|
||||
// 根据数据库类型选择不同的策略
|
||||
if common.UsingSQLite {
|
||||
// SQLite 不支持嵌套事务,使用顺序操作 + 手动回滚
|
||||
return userCheckinWithoutTransaction(checkin, userId, quotaAwarded)
|
||||
}
|
||||
|
||||
// MySQL 和 PostgreSQL 支持事务,使用事务保证原子性
|
||||
return userCheckinWithTransaction(checkin, userId, quotaAwarded)
|
||||
}
|
||||
|
||||
// userCheckinWithTransaction 使用事务执行签到(适用于 MySQL 和 PostgreSQL)
|
||||
func userCheckinWithTransaction(checkin *Checkin, userId int, quotaAwarded int) (*Checkin, error) {
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
// 步骤1: 创建签到记录
|
||||
// 数据库有唯一约束 (user_id, checkin_date),可以防止并发重复签到
|
||||
if err := tx.Create(checkin).Error; err != nil {
|
||||
return errors.New("签到失败,请稍后重试")
|
||||
}
|
||||
|
||||
// 步骤2: 在事务中增加用户额度
|
||||
if err := tx.Model(&User{}).Where("id = ?", userId).
|
||||
Update("quota", gorm.Expr("quota + ?", quotaAwarded)).Error; err != nil {
|
||||
return errors.New("签到失败:更新额度出错")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 事务成功后,异步更新缓存
|
||||
go func() {
|
||||
_ = cacheIncrUserQuota(userId, int64(quotaAwarded))
|
||||
}()
|
||||
|
||||
return checkin, nil
|
||||
}
|
||||
|
||||
// userCheckinWithoutTransaction 不使用事务执行签到(适用于 SQLite)
|
||||
func userCheckinWithoutTransaction(checkin *Checkin, userId int, quotaAwarded int) (*Checkin, error) {
|
||||
// 步骤1: 创建签到记录
|
||||
// 数据库有唯一约束 (user_id, checkin_date),可以防止并发重复签到
|
||||
if err := DB.Create(checkin).Error; err != nil {
|
||||
return nil, errors.New("签到失败,请稍后重试")
|
||||
}
|
||||
|
||||
// 步骤2: 增加用户额度
|
||||
// 使用 db=true 强制直接写入数据库,不使用批量更新
|
||||
if err := IncreaseUserQuota(userId, quotaAwarded, true); err != nil {
|
||||
// 如果增加额度失败,需要回滚签到记录
|
||||
DB.Delete(checkin)
|
||||
return nil, errors.New("签到失败:更新额度出错")
|
||||
}
|
||||
|
||||
return checkin, nil
|
||||
}
|
||||
|
||||
// GetUserCheckinStats 获取用户签到统计信息
|
||||
func GetUserCheckinStats(userId int, month string) (map[string]interface{}, error) {
|
||||
// 获取指定月份的所有签到记录
|
||||
startDate := month + "-01"
|
||||
endDate := month + "-31"
|
||||
|
||||
records, err := GetUserCheckinRecords(userId, startDate, endDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为不包含敏感字段的记录
|
||||
checkinRecords := make([]CheckinRecord, len(records))
|
||||
for i, r := range records {
|
||||
checkinRecords[i] = CheckinRecord{
|
||||
CheckinDate: r.CheckinDate,
|
||||
QuotaAwarded: r.QuotaAwarded,
|
||||
}
|
||||
}
|
||||
|
||||
// 检查今天是否已签到
|
||||
hasCheckedToday, _ := HasCheckedInToday(userId)
|
||||
|
||||
// 获取用户所有时间的签到统计
|
||||
var totalCheckins int64
|
||||
var totalQuota int64
|
||||
DB.Model(&Checkin{}).Where("user_id = ?", userId).Count(&totalCheckins)
|
||||
DB.Model(&Checkin{}).Where("user_id = ?", userId).Select("COALESCE(SUM(quota_awarded), 0)").Scan(&totalQuota)
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_quota": totalQuota, // 所有时间累计获得的额度
|
||||
"total_checkins": totalCheckins, // 所有时间累计签到次数
|
||||
"checkin_count": len(records), // 本月签到次数
|
||||
"checked_in_today": hasCheckedToday, // 今天是否已签到
|
||||
"records": checkinRecords, // 本月签到记录详情(不含id和user_id)
|
||||
}, nil
|
||||
}
|
||||
@@ -267,6 +267,7 @@ func migrateDB() error {
|
||||
&Setup{},
|
||||
&TwoFA{},
|
||||
&TwoFABackupCode{},
|
||||
&Checkin{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -300,6 +301,7 @@ func migrateDBFast() error {
|
||||
{&Setup{}, "Setup"},
|
||||
{&TwoFA{}, "TwoFA"},
|
||||
{&TwoFABackupCode{}, "TwoFABackupCode"},
|
||||
{&Checkin{}, "Checkin"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
|
||||
@@ -143,6 +143,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||
common.OptionMap["AutomaticDisableStatusCodes"] = operation_setting.AutomaticDisableStatusCodesToString()
|
||||
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
|
||||
|
||||
// 自动添加所有注册的模型配置
|
||||
@@ -444,6 +445,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.SensitiveWordsFromString(value)
|
||||
case "AutomaticDisableKeywords":
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "AutomaticDisableStatusCodes":
|
||||
err = operation_setting.AutomaticDisableStatusCodesFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
|
||||
@@ -26,7 +26,7 @@ type Token struct {
|
||||
AllowIps *string `json:"allow_ips" gorm:"default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
Group string `json:"group" gorm:"default:''"`
|
||||
CrossGroupRetry bool `json:"cross_group_retry" gorm:"default:false"` // 跨分组重试,仅auto分组有效
|
||||
CrossGroupRetry bool `json:"cross_group_retry"` // 跨分组重试,仅auto分组有效
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
|
||||
219
pkg/ionet/client.go
Normal file
219
pkg/ionet/client.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package ionet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultEnterpriseBaseURL = "https://api.io.solutions/enterprise/v1/io-cloud/caas"
|
||||
DefaultBaseURL = "https://api.io.solutions/v1/io-cloud/caas"
|
||||
DefaultTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// DefaultHTTPClient is the default HTTP client implementation
|
||||
type DefaultHTTPClient struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewDefaultHTTPClient creates a new default HTTP client
|
||||
func NewDefaultHTTPClient(timeout time.Duration) *DefaultHTTPClient {
|
||||
return &DefaultHTTPClient{
|
||||
client: &http.Client{
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Do executes an HTTP request
|
||||
func (c *DefaultHTTPClient) Do(req *HTTPRequest) (*HTTPResponse, error) {
|
||||
httpReq, err := http.NewRequest(req.Method, req.URL, bytes.NewReader(req.Body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
for key, value := range req.Headers {
|
||||
httpReq.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("HTTP request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
var body bytes.Buffer
|
||||
_, err = body.ReadFrom(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
// Convert headers
|
||||
headers := make(map[string]string)
|
||||
for key, values := range resp.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
return &HTTPResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Headers: headers,
|
||||
Body: body.Bytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewEnterpriseClient creates a new IO.NET API client targeting the enterprise API base URL.
|
||||
func NewEnterpriseClient(apiKey string) *Client {
|
||||
return NewClientWithConfig(apiKey, DefaultEnterpriseBaseURL, nil)
|
||||
}
|
||||
|
||||
// NewClient creates a new IO.NET API client targeting the public API base URL.
|
||||
func NewClient(apiKey string) *Client {
|
||||
return NewClientWithConfig(apiKey, DefaultBaseURL, nil)
|
||||
}
|
||||
|
||||
// NewClientWithConfig creates a new IO.NET API client with custom configuration
|
||||
func NewClientWithConfig(apiKey, baseURL string, httpClient HTTPClient) *Client {
|
||||
if baseURL == "" {
|
||||
baseURL = DefaultBaseURL
|
||||
}
|
||||
if httpClient == nil {
|
||||
httpClient = NewDefaultHTTPClient(DefaultTimeout)
|
||||
}
|
||||
return &Client{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
// makeRequest performs an HTTP request and handles common response processing
|
||||
func (c *Client) makeRequest(method, endpoint string, body interface{}) (*HTTPResponse, error) {
|
||||
var reqBody []byte
|
||||
var err error
|
||||
|
||||
if body != nil {
|
||||
reqBody, err = json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"X-API-KEY": c.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
req := &HTTPRequest{
|
||||
Method: method,
|
||||
URL: c.BaseURL + endpoint,
|
||||
Headers: headers,
|
||||
Body: reqBody,
|
||||
}
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
// Handle API errors
|
||||
if resp.StatusCode >= 400 {
|
||||
var apiErr APIError
|
||||
if len(resp.Body) > 0 {
|
||||
// Try to parse the actual error format: {"detail": "message"}
|
||||
var errorResp struct {
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body, &errorResp); err == nil && errorResp.Detail != "" {
|
||||
apiErr = APIError{
|
||||
Code: resp.StatusCode,
|
||||
Message: errorResp.Detail,
|
||||
}
|
||||
} else {
|
||||
// Fallback: use raw body as details
|
||||
apiErr = APIError{
|
||||
Code: resp.StatusCode,
|
||||
Message: fmt.Sprintf("API request failed with status %d", resp.StatusCode),
|
||||
Details: string(resp.Body),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
apiErr = APIError{
|
||||
Code: resp.StatusCode,
|
||||
Message: fmt.Sprintf("API request failed with status %d", resp.StatusCode),
|
||||
}
|
||||
}
|
||||
return nil, &apiErr
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// buildQueryParams builds query parameters for GET requests
|
||||
func buildQueryParams(params map[string]interface{}) string {
|
||||
if len(params) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
for key, value := range params {
|
||||
if value == nil {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
values.Add(key, v)
|
||||
}
|
||||
case int:
|
||||
if v != 0 {
|
||||
values.Add(key, strconv.Itoa(v))
|
||||
}
|
||||
case int64:
|
||||
if v != 0 {
|
||||
values.Add(key, strconv.FormatInt(v, 10))
|
||||
}
|
||||
case float64:
|
||||
if v != 0 {
|
||||
values.Add(key, strconv.FormatFloat(v, 'f', -1, 64))
|
||||
}
|
||||
case bool:
|
||||
values.Add(key, strconv.FormatBool(v))
|
||||
case time.Time:
|
||||
if !v.IsZero() {
|
||||
values.Add(key, v.Format(time.RFC3339))
|
||||
}
|
||||
case *time.Time:
|
||||
if v != nil && !v.IsZero() {
|
||||
values.Add(key, v.Format(time.RFC3339))
|
||||
}
|
||||
case []int:
|
||||
if len(v) > 0 {
|
||||
if encoded, err := json.Marshal(v); err == nil {
|
||||
values.Add(key, string(encoded))
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
if len(v) > 0 {
|
||||
if encoded, err := json.Marshal(v); err == nil {
|
||||
values.Add(key, string(encoded))
|
||||
}
|
||||
}
|
||||
default:
|
||||
values.Add(key, fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
return "?" + values.Encode()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
302
pkg/ionet/container.go
Normal file
302
pkg/ionet/container.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package ionet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// ListContainers retrieves all containers for a specific deployment
|
||||
func (c *Client) ListContainers(deploymentID string) (*ContainerList, error) {
|
||||
if deploymentID == "" {
|
||||
return nil, fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/containers", deploymentID)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list containers: %w", err)
|
||||
}
|
||||
|
||||
var containerList ContainerList
|
||||
if err := decodeDataWithFlexibleTimes(resp.Body, &containerList); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse containers list: %w", err)
|
||||
}
|
||||
|
||||
return &containerList, nil
|
||||
}
|
||||
|
||||
// GetContainerDetails retrieves detailed information about a specific container
|
||||
func (c *Client) GetContainerDetails(deploymentID, containerID string) (*Container, error) {
|
||||
if deploymentID == "" {
|
||||
return nil, fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if containerID == "" {
|
||||
return nil, fmt.Errorf("container ID cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/container/%s", deploymentID, containerID)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get container details: %w", err)
|
||||
}
|
||||
|
||||
// API response format not documented, assuming direct format
|
||||
var container Container
|
||||
if err := decodeWithFlexibleTimes(resp.Body, &container); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse container details: %w", err)
|
||||
}
|
||||
|
||||
return &container, nil
|
||||
}
|
||||
|
||||
// GetContainerJobs retrieves containers jobs for a specific container (similar to containers endpoint)
|
||||
func (c *Client) GetContainerJobs(deploymentID, containerID string) (*ContainerList, error) {
|
||||
if deploymentID == "" {
|
||||
return nil, fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if containerID == "" {
|
||||
return nil, fmt.Errorf("container ID cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/containers-jobs/%s", deploymentID, containerID)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get container jobs: %w", err)
|
||||
}
|
||||
|
||||
var containerList ContainerList
|
||||
if err := decodeDataWithFlexibleTimes(resp.Body, &containerList); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse container jobs: %w", err)
|
||||
}
|
||||
|
||||
return &containerList, nil
|
||||
}
|
||||
|
||||
// buildLogEndpoint constructs the request path for fetching logs
|
||||
func buildLogEndpoint(deploymentID, containerID string, opts *GetLogsOptions) (string, error) {
|
||||
if deploymentID == "" {
|
||||
return "", fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if containerID == "" {
|
||||
return "", fmt.Errorf("container ID cannot be empty")
|
||||
}
|
||||
|
||||
params := make(map[string]interface{})
|
||||
|
||||
if opts != nil {
|
||||
if opts.Level != "" {
|
||||
params["level"] = opts.Level
|
||||
}
|
||||
if opts.Stream != "" {
|
||||
params["stream"] = opts.Stream
|
||||
}
|
||||
if opts.Limit > 0 {
|
||||
params["limit"] = opts.Limit
|
||||
}
|
||||
if opts.Cursor != "" {
|
||||
params["cursor"] = opts.Cursor
|
||||
}
|
||||
if opts.Follow {
|
||||
params["follow"] = true
|
||||
}
|
||||
|
||||
if opts.StartTime != nil {
|
||||
params["start_time"] = opts.StartTime
|
||||
}
|
||||
if opts.EndTime != nil {
|
||||
params["end_time"] = opts.EndTime
|
||||
}
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/log/%s", deploymentID, containerID)
|
||||
endpoint += buildQueryParams(params)
|
||||
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
// GetContainerLogs retrieves logs for containers in a deployment and normalizes them
|
||||
func (c *Client) GetContainerLogs(deploymentID, containerID string, opts *GetLogsOptions) (*ContainerLogs, error) {
|
||||
raw, err := c.GetContainerLogsRaw(deploymentID, containerID, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logs := &ContainerLogs{
|
||||
ContainerID: containerID,
|
||||
}
|
||||
|
||||
if raw == "" {
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
normalized := strings.ReplaceAll(raw, "\r\n", "\n")
|
||||
lines := strings.Split(normalized, "\n")
|
||||
logs.Logs = lo.FilterMap(lines, func(line string, _ int) (LogEntry, bool) {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
return LogEntry{}, false
|
||||
}
|
||||
return LogEntry{Message: line}, true
|
||||
})
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
// GetContainerLogsRaw retrieves the raw text logs for a specific container
|
||||
func (c *Client) GetContainerLogsRaw(deploymentID, containerID string, opts *GetLogsOptions) (string, error) {
|
||||
endpoint, err := buildLogEndpoint(deploymentID, containerID, opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get container logs: %w", err)
|
||||
}
|
||||
|
||||
return string(resp.Body), nil
|
||||
}
|
||||
|
||||
// StreamContainerLogs streams real-time logs for a specific container
|
||||
// This method uses a callback function to handle incoming log entries
|
||||
func (c *Client) StreamContainerLogs(deploymentID, containerID string, opts *GetLogsOptions, callback func(*LogEntry) error) error {
|
||||
if deploymentID == "" {
|
||||
return fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if containerID == "" {
|
||||
return fmt.Errorf("container ID cannot be empty")
|
||||
}
|
||||
if callback == nil {
|
||||
return fmt.Errorf("callback function cannot be nil")
|
||||
}
|
||||
|
||||
// Set follow to true for streaming
|
||||
if opts == nil {
|
||||
opts = &GetLogsOptions{}
|
||||
}
|
||||
opts.Follow = true
|
||||
|
||||
endpoint, err := buildLogEndpoint(deploymentID, containerID, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Note: This is a simplified implementation. In a real scenario, you might want to use
|
||||
// Server-Sent Events (SSE) or WebSocket for streaming logs
|
||||
for {
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stream container logs: %w", err)
|
||||
}
|
||||
|
||||
var logs ContainerLogs
|
||||
if err := decodeWithFlexibleTimes(resp.Body, &logs); err != nil {
|
||||
return fmt.Errorf("failed to parse container logs: %w", err)
|
||||
}
|
||||
|
||||
// Call the callback for each log entry
|
||||
for _, logEntry := range logs.Logs {
|
||||
if err := callback(&logEntry); err != nil {
|
||||
return fmt.Errorf("callback error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no more logs or we have a cursor, continue polling
|
||||
if !logs.HasMore && logs.NextCursor == "" {
|
||||
break
|
||||
}
|
||||
|
||||
// Update cursor for next request
|
||||
if logs.NextCursor != "" {
|
||||
opts.Cursor = logs.NextCursor
|
||||
endpoint, err = buildLogEndpoint(deploymentID, containerID, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Wait a bit before next poll to avoid overwhelming the API
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestartContainer restarts a specific container (if supported by the API)
|
||||
func (c *Client) RestartContainer(deploymentID, containerID string) error {
|
||||
if deploymentID == "" {
|
||||
return fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if containerID == "" {
|
||||
return fmt.Errorf("container ID cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/container/%s/restart", deploymentID, containerID)
|
||||
|
||||
_, err := c.makeRequest("POST", endpoint, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to restart container: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopContainer stops a specific container (if supported by the API)
|
||||
func (c *Client) StopContainer(deploymentID, containerID string) error {
|
||||
if deploymentID == "" {
|
||||
return fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if containerID == "" {
|
||||
return fmt.Errorf("container ID cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/container/%s/stop", deploymentID, containerID)
|
||||
|
||||
_, err := c.makeRequest("POST", endpoint, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop container: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteInContainer executes a command in a specific container (if supported by the API)
|
||||
func (c *Client) ExecuteInContainer(deploymentID, containerID string, command []string) (string, error) {
|
||||
if deploymentID == "" {
|
||||
return "", fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if containerID == "" {
|
||||
return "", fmt.Errorf("container ID cannot be empty")
|
||||
}
|
||||
if len(command) == 0 {
|
||||
return "", fmt.Errorf("command cannot be empty")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"command": command,
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/container/%s/exec", deploymentID, containerID)
|
||||
|
||||
resp, err := c.makeRequest("POST", endpoint, reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to execute command in container: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body, &result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse execution result: %w", err)
|
||||
}
|
||||
|
||||
if output, ok := result["output"].(string); ok {
|
||||
return output, nil
|
||||
}
|
||||
|
||||
return string(resp.Body), nil
|
||||
}
|
||||
377
pkg/ionet/deployment.go
Normal file
377
pkg/ionet/deployment.go
Normal file
@@ -0,0 +1,377 @@
|
||||
package ionet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// DeployContainer deploys a new container with the specified configuration
|
||||
func (c *Client) DeployContainer(req *DeploymentRequest) (*DeploymentResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("deployment request cannot be nil")
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.ResourcePrivateName == "" {
|
||||
return nil, fmt.Errorf("resource_private_name is required")
|
||||
}
|
||||
if len(req.LocationIDs) == 0 {
|
||||
return nil, fmt.Errorf("location_ids is required")
|
||||
}
|
||||
if req.HardwareID <= 0 {
|
||||
return nil, fmt.Errorf("hardware_id is required")
|
||||
}
|
||||
if req.RegistryConfig.ImageURL == "" {
|
||||
return nil, fmt.Errorf("registry_config.image_url is required")
|
||||
}
|
||||
if req.GPUsPerContainer < 1 {
|
||||
return nil, fmt.Errorf("gpus_per_container must be at least 1")
|
||||
}
|
||||
if req.DurationHours < 1 {
|
||||
return nil, fmt.Errorf("duration_hours must be at least 1")
|
||||
}
|
||||
if req.ContainerConfig.ReplicaCount < 1 {
|
||||
return nil, fmt.Errorf("container_config.replica_count must be at least 1")
|
||||
}
|
||||
|
||||
resp, err := c.makeRequest("POST", "/deploy", req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to deploy container: %w", err)
|
||||
}
|
||||
|
||||
// API returns direct format:
|
||||
// {"status": "string", "deployment_id": "..."}
|
||||
var deployResp DeploymentResponse
|
||||
if err := json.Unmarshal(resp.Body, &deployResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse deployment response: %w", err)
|
||||
}
|
||||
|
||||
return &deployResp, nil
|
||||
}
|
||||
|
||||
// ListDeployments retrieves a list of deployments with optional filtering
|
||||
func (c *Client) ListDeployments(opts *ListDeploymentsOptions) (*DeploymentList, error) {
|
||||
params := make(map[string]interface{})
|
||||
|
||||
if opts != nil {
|
||||
params["status"] = opts.Status
|
||||
params["location_id"] = opts.LocationID
|
||||
params["page"] = opts.Page
|
||||
params["page_size"] = opts.PageSize
|
||||
params["sort_by"] = opts.SortBy
|
||||
params["sort_order"] = opts.SortOrder
|
||||
}
|
||||
|
||||
endpoint := "/deployments" + buildQueryParams(params)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list deployments: %w", err)
|
||||
}
|
||||
|
||||
var deploymentList DeploymentList
|
||||
if err := decodeData(resp.Body, &deploymentList); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse deployments list: %w", err)
|
||||
}
|
||||
|
||||
deploymentList.Deployments = lo.Map(deploymentList.Deployments, func(deployment Deployment, _ int) Deployment {
|
||||
deployment.GPUCount = deployment.HardwareQuantity
|
||||
deployment.Replicas = deployment.HardwareQuantity // Assuming 1:1 mapping for now
|
||||
return deployment
|
||||
})
|
||||
|
||||
return &deploymentList, nil
|
||||
}
|
||||
|
||||
// GetDeployment retrieves detailed information about a specific deployment
|
||||
func (c *Client) GetDeployment(deploymentID string) (*DeploymentDetail, error) {
|
||||
if deploymentID == "" {
|
||||
return nil, fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s", deploymentID)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get deployment details: %w", err)
|
||||
}
|
||||
|
||||
var deploymentDetail DeploymentDetail
|
||||
if err := decodeDataWithFlexibleTimes(resp.Body, &deploymentDetail); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse deployment details: %w", err)
|
||||
}
|
||||
|
||||
return &deploymentDetail, nil
|
||||
}
|
||||
|
||||
// UpdateDeployment updates the configuration of an existing deployment
|
||||
func (c *Client) UpdateDeployment(deploymentID string, req *UpdateDeploymentRequest) (*UpdateDeploymentResponse, error) {
|
||||
if deploymentID == "" {
|
||||
return nil, fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("update request cannot be nil")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s", deploymentID)
|
||||
|
||||
resp, err := c.makeRequest("PATCH", endpoint, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update deployment: %w", err)
|
||||
}
|
||||
|
||||
// API returns direct format:
|
||||
// {"status": "string", "deployment_id": "..."}
|
||||
var updateResp UpdateDeploymentResponse
|
||||
if err := json.Unmarshal(resp.Body, &updateResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update deployment response: %w", err)
|
||||
}
|
||||
|
||||
return &updateResp, nil
|
||||
}
|
||||
|
||||
// ExtendDeployment extends the duration of an existing deployment
|
||||
func (c *Client) ExtendDeployment(deploymentID string, req *ExtendDurationRequest) (*DeploymentDetail, error) {
|
||||
if deploymentID == "" {
|
||||
return nil, fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("extend request cannot be nil")
|
||||
}
|
||||
if req.DurationHours < 1 {
|
||||
return nil, fmt.Errorf("duration_hours must be at least 1")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s/extend", deploymentID)
|
||||
|
||||
resp, err := c.makeRequest("POST", endpoint, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extend deployment: %w", err)
|
||||
}
|
||||
|
||||
var deploymentDetail DeploymentDetail
|
||||
if err := decodeDataWithFlexibleTimes(resp.Body, &deploymentDetail); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse extended deployment details: %w", err)
|
||||
}
|
||||
|
||||
return &deploymentDetail, nil
|
||||
}
|
||||
|
||||
// DeleteDeployment deletes an active deployment
|
||||
func (c *Client) DeleteDeployment(deploymentID string) (*UpdateDeploymentResponse, error) {
|
||||
if deploymentID == "" {
|
||||
return nil, fmt.Errorf("deployment ID cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/deployment/%s", deploymentID)
|
||||
|
||||
resp, err := c.makeRequest("DELETE", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete deployment: %w", err)
|
||||
}
|
||||
|
||||
// API returns direct format:
|
||||
// {"status": "string", "deployment_id": "..."}
|
||||
var deleteResp UpdateDeploymentResponse
|
||||
if err := json.Unmarshal(resp.Body, &deleteResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse delete deployment response: %w", err)
|
||||
}
|
||||
|
||||
return &deleteResp, nil
|
||||
}
|
||||
|
||||
// GetPriceEstimation calculates the estimated cost for a deployment
|
||||
func (c *Client) GetPriceEstimation(req *PriceEstimationRequest) (*PriceEstimationResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("price estimation request cannot be nil")
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if len(req.LocationIDs) == 0 {
|
||||
return nil, fmt.Errorf("location_ids is required")
|
||||
}
|
||||
if req.HardwareID == 0 {
|
||||
return nil, fmt.Errorf("hardware_id is required")
|
||||
}
|
||||
if req.ReplicaCount < 1 {
|
||||
return nil, fmt.Errorf("replica_count must be at least 1")
|
||||
}
|
||||
|
||||
currency := strings.TrimSpace(req.Currency)
|
||||
if currency == "" {
|
||||
currency = "usdc"
|
||||
}
|
||||
|
||||
durationType := strings.TrimSpace(req.DurationType)
|
||||
if durationType == "" {
|
||||
durationType = "hour"
|
||||
}
|
||||
durationType = strings.ToLower(durationType)
|
||||
|
||||
apiDurationType := ""
|
||||
|
||||
durationQty := req.DurationQty
|
||||
if durationQty < 1 {
|
||||
durationQty = req.DurationHours
|
||||
}
|
||||
if durationQty < 1 {
|
||||
return nil, fmt.Errorf("duration_qty must be at least 1")
|
||||
}
|
||||
|
||||
hardwareQty := req.HardwareQty
|
||||
if hardwareQty < 1 {
|
||||
hardwareQty = req.GPUsPerContainer
|
||||
}
|
||||
if hardwareQty < 1 {
|
||||
return nil, fmt.Errorf("hardware_qty must be at least 1")
|
||||
}
|
||||
|
||||
durationHoursForRate := req.DurationHours
|
||||
if durationHoursForRate < 1 {
|
||||
durationHoursForRate = durationQty
|
||||
}
|
||||
switch durationType {
|
||||
case "hour", "hours", "hourly":
|
||||
durationHoursForRate = durationQty
|
||||
apiDurationType = "hourly"
|
||||
case "day", "days", "daily":
|
||||
durationHoursForRate = durationQty * 24
|
||||
apiDurationType = "daily"
|
||||
case "week", "weeks", "weekly":
|
||||
durationHoursForRate = durationQty * 24 * 7
|
||||
apiDurationType = "weekly"
|
||||
case "month", "months", "monthly":
|
||||
durationHoursForRate = durationQty * 24 * 30
|
||||
apiDurationType = "monthly"
|
||||
}
|
||||
if durationHoursForRate < 1 {
|
||||
durationHoursForRate = 1
|
||||
}
|
||||
if apiDurationType == "" {
|
||||
apiDurationType = "hourly"
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"location_ids": req.LocationIDs,
|
||||
"hardware_id": req.HardwareID,
|
||||
"hardware_qty": hardwareQty,
|
||||
"gpus_per_container": req.GPUsPerContainer,
|
||||
"duration_type": apiDurationType,
|
||||
"duration_qty": durationQty,
|
||||
"duration_hours": req.DurationHours,
|
||||
"replica_count": req.ReplicaCount,
|
||||
"currency": currency,
|
||||
}
|
||||
|
||||
endpoint := "/price" + buildQueryParams(params)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get price estimation: %w", err)
|
||||
}
|
||||
|
||||
// Parse according to the actual API response format from docs:
|
||||
// {
|
||||
// "data": {
|
||||
// "replica_count": 0,
|
||||
// "gpus_per_container": 0,
|
||||
// "available_replica_count": [0],
|
||||
// "discount": 0,
|
||||
// "ionet_fee": 0,
|
||||
// "ionet_fee_percent": 0,
|
||||
// "currency_conversion_fee": 0,
|
||||
// "currency_conversion_fee_percent": 0,
|
||||
// "total_cost_usdc": 0
|
||||
// }
|
||||
// }
|
||||
var pricingData struct {
|
||||
ReplicaCount int `json:"replica_count"`
|
||||
GPUsPerContainer int `json:"gpus_per_container"`
|
||||
AvailableReplicaCount []int `json:"available_replica_count"`
|
||||
Discount float64 `json:"discount"`
|
||||
IonetFee float64 `json:"ionet_fee"`
|
||||
IonetFeePercent float64 `json:"ionet_fee_percent"`
|
||||
CurrencyConversionFee float64 `json:"currency_conversion_fee"`
|
||||
CurrencyConversionFeePercent float64 `json:"currency_conversion_fee_percent"`
|
||||
TotalCostUSDC float64 `json:"total_cost_usdc"`
|
||||
}
|
||||
|
||||
if err := decodeData(resp.Body, &pricingData); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse price estimation response: %w", err)
|
||||
}
|
||||
|
||||
// Convert to our internal format
|
||||
durationHoursFloat := float64(durationHoursForRate)
|
||||
if durationHoursFloat <= 0 {
|
||||
durationHoursFloat = 1
|
||||
}
|
||||
|
||||
priceResp := &PriceEstimationResponse{
|
||||
EstimatedCost: pricingData.TotalCostUSDC,
|
||||
Currency: strings.ToUpper(currency),
|
||||
EstimationValid: true,
|
||||
PriceBreakdown: PriceBreakdown{
|
||||
ComputeCost: pricingData.TotalCostUSDC - pricingData.IonetFee - pricingData.CurrencyConversionFee,
|
||||
TotalCost: pricingData.TotalCostUSDC,
|
||||
HourlyRate: pricingData.TotalCostUSDC / durationHoursFloat,
|
||||
},
|
||||
}
|
||||
|
||||
return priceResp, nil
|
||||
}
|
||||
|
||||
// CheckClusterNameAvailability checks if a cluster name is available
|
||||
func (c *Client) CheckClusterNameAvailability(clusterName string) (bool, error) {
|
||||
if clusterName == "" {
|
||||
return false, fmt.Errorf("cluster name cannot be empty")
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"cluster_name": clusterName,
|
||||
}
|
||||
|
||||
endpoint := "/clusters/check_cluster_name_availability" + buildQueryParams(params)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check cluster name availability: %w", err)
|
||||
}
|
||||
|
||||
var availabilityResp bool
|
||||
if err := json.Unmarshal(resp.Body, &availabilityResp); err != nil {
|
||||
return false, fmt.Errorf("failed to parse cluster name availability response: %w", err)
|
||||
}
|
||||
|
||||
return availabilityResp, nil
|
||||
}
|
||||
|
||||
// UpdateClusterName updates the name of an existing cluster/deployment
|
||||
func (c *Client) UpdateClusterName(clusterID string, req *UpdateClusterNameRequest) (*UpdateClusterNameResponse, error) {
|
||||
if clusterID == "" {
|
||||
return nil, fmt.Errorf("cluster ID cannot be empty")
|
||||
}
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("update cluster name request cannot be nil")
|
||||
}
|
||||
if req.Name == "" {
|
||||
return nil, fmt.Errorf("cluster name cannot be empty")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/clusters/%s/update-name", clusterID)
|
||||
|
||||
resp, err := c.makeRequest("PUT", endpoint, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update cluster name: %w", err)
|
||||
}
|
||||
|
||||
// Parse the response directly without data wrapper based on API docs
|
||||
var updateResp UpdateClusterNameResponse
|
||||
if err := json.Unmarshal(resp.Body, &updateResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update cluster name response: %w", err)
|
||||
}
|
||||
|
||||
return &updateResp, nil
|
||||
}
|
||||
202
pkg/ionet/hardware.go
Normal file
202
pkg/ionet/hardware.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package ionet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// GetAvailableReplicas retrieves available replicas per location for specified hardware
|
||||
func (c *Client) GetAvailableReplicas(hardwareID int, gpuCount int) (*AvailableReplicasResponse, error) {
|
||||
if hardwareID <= 0 {
|
||||
return nil, fmt.Errorf("hardware_id must be greater than 0")
|
||||
}
|
||||
if gpuCount < 1 {
|
||||
return nil, fmt.Errorf("gpu_count must be at least 1")
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"hardware_id": hardwareID,
|
||||
"hardware_qty": gpuCount,
|
||||
}
|
||||
|
||||
endpoint := "/available-replicas" + buildQueryParams(params)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get available replicas: %w", err)
|
||||
}
|
||||
|
||||
type availableReplicaPayload struct {
|
||||
ID int `json:"id"`
|
||||
ISO2 string `json:"iso2"`
|
||||
Name string `json:"name"`
|
||||
AvailableReplicas int `json:"available_replicas"`
|
||||
}
|
||||
var payload []availableReplicaPayload
|
||||
|
||||
if err := decodeData(resp.Body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse available replicas response: %w", err)
|
||||
}
|
||||
|
||||
replicas := lo.Map(payload, func(item availableReplicaPayload, _ int) AvailableReplica {
|
||||
return AvailableReplica{
|
||||
LocationID: item.ID,
|
||||
LocationName: item.Name,
|
||||
HardwareID: hardwareID,
|
||||
HardwareName: "",
|
||||
AvailableCount: item.AvailableReplicas,
|
||||
MaxGPUs: gpuCount,
|
||||
}
|
||||
})
|
||||
|
||||
return &AvailableReplicasResponse{Replicas: replicas}, nil
|
||||
}
|
||||
|
||||
// GetMaxGPUsPerContainer retrieves the maximum number of GPUs available per hardware type
|
||||
func (c *Client) GetMaxGPUsPerContainer() (*MaxGPUResponse, error) {
|
||||
resp, err := c.makeRequest("GET", "/hardware/max-gpus-per-container", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get max GPUs per container: %w", err)
|
||||
}
|
||||
|
||||
var maxGPUResp MaxGPUResponse
|
||||
if err := decodeData(resp.Body, &maxGPUResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse max GPU response: %w", err)
|
||||
}
|
||||
|
||||
return &maxGPUResp, nil
|
||||
}
|
||||
|
||||
// ListHardwareTypes retrieves available hardware types using the max GPUs endpoint
|
||||
func (c *Client) ListHardwareTypes() ([]HardwareType, int, error) {
|
||||
maxGPUResp, err := c.GetMaxGPUsPerContainer()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to list hardware types: %w", err)
|
||||
}
|
||||
|
||||
mapped := lo.Map(maxGPUResp.Hardware, func(hw MaxGPUInfo, _ int) HardwareType {
|
||||
name := strings.TrimSpace(hw.HardwareName)
|
||||
if name == "" {
|
||||
name = fmt.Sprintf("Hardware %d", hw.HardwareID)
|
||||
}
|
||||
|
||||
return HardwareType{
|
||||
ID: hw.HardwareID,
|
||||
Name: name,
|
||||
GPUType: "",
|
||||
GPUMemory: 0,
|
||||
MaxGPUs: hw.MaxGPUsPerContainer,
|
||||
CPU: "",
|
||||
Memory: 0,
|
||||
Storage: 0,
|
||||
HourlyRate: 0,
|
||||
Available: hw.Available > 0,
|
||||
BrandName: strings.TrimSpace(hw.BrandName),
|
||||
AvailableCount: hw.Available,
|
||||
}
|
||||
})
|
||||
|
||||
totalAvailable := maxGPUResp.Total
|
||||
if totalAvailable == 0 {
|
||||
totalAvailable = lo.SumBy(maxGPUResp.Hardware, func(hw MaxGPUInfo) int {
|
||||
return hw.Available
|
||||
})
|
||||
}
|
||||
|
||||
return mapped, totalAvailable, nil
|
||||
}
|
||||
|
||||
// ListLocations retrieves available deployment locations (if supported by the API)
|
||||
func (c *Client) ListLocations() (*LocationsResponse, error) {
|
||||
resp, err := c.makeRequest("GET", "/locations", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list locations: %w", err)
|
||||
}
|
||||
|
||||
var locations LocationsResponse
|
||||
if err := decodeData(resp.Body, &locations); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse locations response: %w", err)
|
||||
}
|
||||
|
||||
locations.Locations = lo.Map(locations.Locations, func(location Location, _ int) Location {
|
||||
location.ISO2 = strings.ToUpper(strings.TrimSpace(location.ISO2))
|
||||
return location
|
||||
})
|
||||
|
||||
if locations.Total == 0 {
|
||||
locations.Total = lo.SumBy(locations.Locations, func(location Location) int {
|
||||
return location.Available
|
||||
})
|
||||
}
|
||||
|
||||
return &locations, nil
|
||||
}
|
||||
|
||||
// GetHardwareType retrieves details about a specific hardware type
|
||||
func (c *Client) GetHardwareType(hardwareID int) (*HardwareType, error) {
|
||||
if hardwareID <= 0 {
|
||||
return nil, fmt.Errorf("hardware ID must be greater than 0")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/hardware/types/%d", hardwareID)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get hardware type: %w", err)
|
||||
}
|
||||
|
||||
// API response format not documented, assuming direct format
|
||||
var hardwareType HardwareType
|
||||
if err := json.Unmarshal(resp.Body, &hardwareType); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse hardware type: %w", err)
|
||||
}
|
||||
|
||||
return &hardwareType, nil
|
||||
}
|
||||
|
||||
// GetLocation retrieves details about a specific location
|
||||
func (c *Client) GetLocation(locationID int) (*Location, error) {
|
||||
if locationID <= 0 {
|
||||
return nil, fmt.Errorf("location ID must be greater than 0")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/locations/%d", locationID)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get location: %w", err)
|
||||
}
|
||||
|
||||
// API response format not documented, assuming direct format
|
||||
var location Location
|
||||
if err := json.Unmarshal(resp.Body, &location); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse location: %w", err)
|
||||
}
|
||||
|
||||
return &location, nil
|
||||
}
|
||||
|
||||
// GetLocationAvailability retrieves real-time availability for a specific location
|
||||
func (c *Client) GetLocationAvailability(locationID int) (*LocationAvailability, error) {
|
||||
if locationID <= 0 {
|
||||
return nil, fmt.Errorf("location ID must be greater than 0")
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("/locations/%d/availability", locationID)
|
||||
|
||||
resp, err := c.makeRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get location availability: %w", err)
|
||||
}
|
||||
|
||||
// API response format not documented, assuming direct format
|
||||
var availability LocationAvailability
|
||||
if err := json.Unmarshal(resp.Body, &availability); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse location availability: %w", err)
|
||||
}
|
||||
|
||||
return &availability, nil
|
||||
}
|
||||
96
pkg/ionet/jsonutil.go
Normal file
96
pkg/ionet/jsonutil.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package ionet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
// decodeWithFlexibleTimes unmarshals API responses while tolerating timestamp strings
|
||||
// that omit timezone information by normalizing them to RFC3339Nano.
|
||||
func decodeWithFlexibleTimes(data []byte, target interface{}) error {
|
||||
var intermediate interface{}
|
||||
if err := json.Unmarshal(data, &intermediate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
normalized := normalizeTimeValues(intermediate)
|
||||
reencoded, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return json.Unmarshal(reencoded, target)
|
||||
}
|
||||
|
||||
func decodeData[T any](data []byte, target *T) error {
|
||||
var wrapper struct {
|
||||
Data T `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &wrapper); err != nil {
|
||||
return err
|
||||
}
|
||||
*target = wrapper.Data
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeDataWithFlexibleTimes[T any](data []byte, target *T) error {
|
||||
var wrapper struct {
|
||||
Data T `json:"data"`
|
||||
}
|
||||
if err := decodeWithFlexibleTimes(data, &wrapper); err != nil {
|
||||
return err
|
||||
}
|
||||
*target = wrapper.Data
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeTimeValues(value interface{}) interface{} {
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}:
|
||||
return lo.MapValues(v, func(val interface{}, _ string) interface{} {
|
||||
return normalizeTimeValues(val)
|
||||
})
|
||||
case []interface{}:
|
||||
return lo.Map(v, func(item interface{}, _ int) interface{} {
|
||||
return normalizeTimeValues(item)
|
||||
})
|
||||
case string:
|
||||
if normalized, changed := normalizeTimeString(v); changed {
|
||||
return normalized
|
||||
}
|
||||
return v
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeTimeString(input string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(input)
|
||||
if trimmed == "" {
|
||||
return input, false
|
||||
}
|
||||
|
||||
if _, err := time.Parse(time.RFC3339Nano, trimmed); err == nil {
|
||||
return trimmed, trimmed != input
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, trimmed); err == nil {
|
||||
return trimmed, trimmed != input
|
||||
}
|
||||
|
||||
layouts := []string{
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999",
|
||||
"2006-01-02T15:04:05",
|
||||
}
|
||||
|
||||
for _, layout := range layouts {
|
||||
if parsed, err := time.Parse(layout, trimmed); err == nil {
|
||||
return parsed.UTC().Format(time.RFC3339Nano), true
|
||||
}
|
||||
}
|
||||
|
||||
return input, false
|
||||
}
|
||||
353
pkg/ionet/types.go
Normal file
353
pkg/ionet/types.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package ionet
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client represents the IO.NET API client
|
||||
type Client struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
|
||||
// HTTPClient interface for making HTTP requests
|
||||
type HTTPClient interface {
|
||||
Do(req *HTTPRequest) (*HTTPResponse, error)
|
||||
}
|
||||
|
||||
// HTTPRequest represents an HTTP request
|
||||
type HTTPRequest struct {
|
||||
Method string
|
||||
URL string
|
||||
Headers map[string]string
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// HTTPResponse represents an HTTP response
|
||||
type HTTPResponse struct {
|
||||
StatusCode int
|
||||
Headers map[string]string
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// DeploymentRequest represents a container deployment request
|
||||
type DeploymentRequest struct {
|
||||
ResourcePrivateName string `json:"resource_private_name"`
|
||||
DurationHours int `json:"duration_hours"`
|
||||
GPUsPerContainer int `json:"gpus_per_container"`
|
||||
HardwareID int `json:"hardware_id"`
|
||||
LocationIDs []int `json:"location_ids"`
|
||||
ContainerConfig ContainerConfig `json:"container_config"`
|
||||
RegistryConfig RegistryConfig `json:"registry_config"`
|
||||
}
|
||||
|
||||
// ContainerConfig represents container configuration
|
||||
type ContainerConfig struct {
|
||||
ReplicaCount int `json:"replica_count"`
|
||||
EnvVariables map[string]string `json:"env_variables,omitempty"`
|
||||
SecretEnvVariables map[string]string `json:"secret_env_variables,omitempty"`
|
||||
Entrypoint []string `json:"entrypoint,omitempty"`
|
||||
TrafficPort int `json:"traffic_port,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
}
|
||||
|
||||
// RegistryConfig represents registry configuration
|
||||
type RegistryConfig struct {
|
||||
ImageURL string `json:"image_url"`
|
||||
RegistryUsername string `json:"registry_username,omitempty"`
|
||||
RegistrySecret string `json:"registry_secret,omitempty"`
|
||||
}
|
||||
|
||||
// DeploymentResponse represents the response from deployment creation
|
||||
type DeploymentResponse struct {
|
||||
DeploymentID string `json:"deployment_id"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// DeploymentDetail represents detailed deployment information
|
||||
type DeploymentDetail struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
FinishedAt *time.Time `json:"finished_at,omitempty"`
|
||||
AmountPaid float64 `json:"amount_paid"`
|
||||
CompletedPercent float64 `json:"completed_percent"`
|
||||
TotalGPUs int `json:"total_gpus"`
|
||||
GPUsPerContainer int `json:"gpus_per_container"`
|
||||
TotalContainers int `json:"total_containers"`
|
||||
HardwareName string `json:"hardware_name"`
|
||||
HardwareID int `json:"hardware_id"`
|
||||
Locations []DeploymentLocation `json:"locations"`
|
||||
BrandName string `json:"brand_name"`
|
||||
ComputeMinutesServed int `json:"compute_minutes_served"`
|
||||
ComputeMinutesRemaining int `json:"compute_minutes_remaining"`
|
||||
ContainerConfig DeploymentContainerConfig `json:"container_config"`
|
||||
}
|
||||
|
||||
// DeploymentLocation represents a location in deployment details
|
||||
type DeploymentLocation struct {
|
||||
ID int `json:"id"`
|
||||
ISO2 string `json:"iso2"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// DeploymentContainerConfig represents container config in deployment details
|
||||
type DeploymentContainerConfig struct {
|
||||
Entrypoint []string `json:"entrypoint"`
|
||||
EnvVariables map[string]interface{} `json:"env_variables"`
|
||||
TrafficPort int `json:"traffic_port"`
|
||||
ImageURL string `json:"image_url"`
|
||||
}
|
||||
|
||||
// Container represents a container within a deployment
|
||||
type Container struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
ContainerID string `json:"container_id"`
|
||||
Hardware string `json:"hardware"`
|
||||
BrandName string `json:"brand_name"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UptimePercent int `json:"uptime_percent"`
|
||||
GPUsPerContainer int `json:"gpus_per_container"`
|
||||
Status string `json:"status"`
|
||||
ContainerEvents []ContainerEvent `json:"container_events"`
|
||||
PublicURL string `json:"public_url"`
|
||||
}
|
||||
|
||||
// ContainerEvent represents a container event
|
||||
type ContainerEvent struct {
|
||||
Time time.Time `json:"time"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ContainerList represents a list of containers
|
||||
type ContainerList struct {
|
||||
Total int `json:"total"`
|
||||
Workers []Container `json:"workers"`
|
||||
}
|
||||
|
||||
// Deployment represents a deployment in the list
|
||||
type Deployment struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Name string `json:"name"`
|
||||
CompletedPercent float64 `json:"completed_percent"`
|
||||
HardwareQuantity int `json:"hardware_quantity"`
|
||||
BrandName string `json:"brand_name"`
|
||||
HardwareName string `json:"hardware_name"`
|
||||
Served string `json:"served"`
|
||||
Remaining string `json:"remaining"`
|
||||
ComputeMinutesServed int `json:"compute_minutes_served"`
|
||||
ComputeMinutesRemaining int `json:"compute_minutes_remaining"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
GPUCount int `json:"-"` // Derived from HardwareQuantity
|
||||
Replicas int `json:"-"` // Derived from HardwareQuantity
|
||||
}
|
||||
|
||||
// DeploymentList represents a list of deployments with pagination
|
||||
type DeploymentList struct {
|
||||
Deployments []Deployment `json:"deployments"`
|
||||
Total int `json:"total"`
|
||||
Statuses []string `json:"statuses"`
|
||||
}
|
||||
|
||||
// AvailableReplica represents replica availability for a location
|
||||
type AvailableReplica struct {
|
||||
LocationID int `json:"location_id"`
|
||||
LocationName string `json:"location_name"`
|
||||
HardwareID int `json:"hardware_id"`
|
||||
HardwareName string `json:"hardware_name"`
|
||||
AvailableCount int `json:"available_count"`
|
||||
MaxGPUs int `json:"max_gpus"`
|
||||
}
|
||||
|
||||
// AvailableReplicasResponse represents the response for available replicas
|
||||
type AvailableReplicasResponse struct {
|
||||
Replicas []AvailableReplica `json:"replicas"`
|
||||
}
|
||||
|
||||
// MaxGPUResponse represents the response for maximum GPUs per container
|
||||
type MaxGPUResponse struct {
|
||||
Hardware []MaxGPUInfo `json:"hardware"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// MaxGPUInfo represents max GPU information for a hardware type
|
||||
type MaxGPUInfo struct {
|
||||
MaxGPUsPerContainer int `json:"max_gpus_per_container"`
|
||||
Available int `json:"available"`
|
||||
HardwareID int `json:"hardware_id"`
|
||||
HardwareName string `json:"hardware_name"`
|
||||
BrandName string `json:"brand_name"`
|
||||
}
|
||||
|
||||
// PriceEstimationRequest represents a price estimation request
|
||||
type PriceEstimationRequest struct {
|
||||
LocationIDs []int `json:"location_ids"`
|
||||
HardwareID int `json:"hardware_id"`
|
||||
GPUsPerContainer int `json:"gpus_per_container"`
|
||||
DurationHours int `json:"duration_hours"`
|
||||
ReplicaCount int `json:"replica_count"`
|
||||
Currency string `json:"currency"`
|
||||
DurationType string `json:"duration_type"`
|
||||
DurationQty int `json:"duration_qty"`
|
||||
HardwareQty int `json:"hardware_qty"`
|
||||
}
|
||||
|
||||
// PriceEstimationResponse represents the price estimation response
|
||||
type PriceEstimationResponse struct {
|
||||
EstimatedCost float64 `json:"estimated_cost"`
|
||||
Currency string `json:"currency"`
|
||||
PriceBreakdown PriceBreakdown `json:"price_breakdown"`
|
||||
EstimationValid bool `json:"estimation_valid"`
|
||||
}
|
||||
|
||||
// PriceBreakdown represents detailed cost breakdown
|
||||
type PriceBreakdown struct {
|
||||
ComputeCost float64 `json:"compute_cost"`
|
||||
NetworkCost float64 `json:"network_cost,omitempty"`
|
||||
StorageCost float64 `json:"storage_cost,omitempty"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
HourlyRate float64 `json:"hourly_rate"`
|
||||
}
|
||||
|
||||
// ContainerLogs represents container log entries
|
||||
type ContainerLogs struct {
|
||||
ContainerID string `json:"container_id"`
|
||||
Logs []LogEntry `json:"logs"`
|
||||
HasMore bool `json:"has_more"`
|
||||
NextCursor string `json:"next_cursor,omitempty"`
|
||||
}
|
||||
|
||||
// LogEntry represents a single log entry
|
||||
type LogEntry struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Level string `json:"level,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Source string `json:"source,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateDeploymentRequest represents request to update deployment configuration
|
||||
type UpdateDeploymentRequest struct {
|
||||
EnvVariables map[string]string `json:"env_variables,omitempty"`
|
||||
SecretEnvVariables map[string]string `json:"secret_env_variables,omitempty"`
|
||||
Entrypoint []string `json:"entrypoint,omitempty"`
|
||||
TrafficPort *int `json:"traffic_port,omitempty"`
|
||||
ImageURL string `json:"image_url,omitempty"`
|
||||
RegistryUsername string `json:"registry_username,omitempty"`
|
||||
RegistrySecret string `json:"registry_secret,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
}
|
||||
|
||||
// ExtendDurationRequest represents request to extend deployment duration
|
||||
type ExtendDurationRequest struct {
|
||||
DurationHours int `json:"duration_hours"`
|
||||
}
|
||||
|
||||
// UpdateDeploymentResponse represents response from deployment update
|
||||
type UpdateDeploymentResponse struct {
|
||||
Status string `json:"status"`
|
||||
DeploymentID string `json:"deployment_id"`
|
||||
}
|
||||
|
||||
// UpdateClusterNameRequest represents request to update cluster name
|
||||
type UpdateClusterNameRequest struct {
|
||||
Name string `json:"cluster_name"`
|
||||
}
|
||||
|
||||
// UpdateClusterNameResponse represents response from cluster name update
|
||||
type UpdateClusterNameResponse struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// APIError represents an API error response
|
||||
type APIError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *APIError) Error() string {
|
||||
if e.Details != "" {
|
||||
return e.Message + ": " + e.Details
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// ListDeploymentsOptions represents options for listing deployments
|
||||
type ListDeploymentsOptions struct {
|
||||
Status string `json:"status,omitempty"` // filter by status
|
||||
LocationID int `json:"location_id,omitempty"` // filter by location
|
||||
Page int `json:"page,omitempty"` // pagination
|
||||
PageSize int `json:"page_size,omitempty"` // pagination
|
||||
SortBy string `json:"sort_by,omitempty"` // sort field
|
||||
SortOrder string `json:"sort_order,omitempty"` // asc/desc
|
||||
}
|
||||
|
||||
// GetLogsOptions represents options for retrieving container logs
|
||||
type GetLogsOptions struct {
|
||||
StartTime *time.Time `json:"start_time,omitempty"`
|
||||
EndTime *time.Time `json:"end_time,omitempty"`
|
||||
Level string `json:"level,omitempty"` // filter by log level
|
||||
Stream string `json:"stream,omitempty"` // filter by stdout/stderr streams
|
||||
Limit int `json:"limit,omitempty"` // max number of log entries
|
||||
Cursor string `json:"cursor,omitempty"` // pagination cursor
|
||||
Follow bool `json:"follow,omitempty"` // stream logs
|
||||
}
|
||||
|
||||
// HardwareType represents a hardware type available for deployment
|
||||
type HardwareType struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
GPUType string `json:"gpu_type"`
|
||||
GPUMemory int `json:"gpu_memory"` // in GB
|
||||
MaxGPUs int `json:"max_gpus"`
|
||||
CPU string `json:"cpu,omitempty"`
|
||||
Memory int `json:"memory,omitempty"` // in GB
|
||||
Storage int `json:"storage,omitempty"` // in GB
|
||||
HourlyRate float64 `json:"hourly_rate"`
|
||||
Available bool `json:"available"`
|
||||
BrandName string `json:"brand_name,omitempty"`
|
||||
AvailableCount int `json:"available_count,omitempty"`
|
||||
}
|
||||
|
||||
// Location represents a deployment location
|
||||
type Location struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ISO2 string `json:"iso2,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
Country string `json:"country,omitempty"`
|
||||
Latitude float64 `json:"latitude,omitempty"`
|
||||
Longitude float64 `json:"longitude,omitempty"`
|
||||
Available int `json:"available,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// LocationsResponse represents the list of locations and aggregated metadata.
|
||||
type LocationsResponse struct {
|
||||
Locations []Location `json:"locations"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// LocationAvailability represents real-time availability for a location
|
||||
type LocationAvailability struct {
|
||||
LocationID int `json:"location_id"`
|
||||
LocationName string `json:"location_name"`
|
||||
Available bool `json:"available"`
|
||||
HardwareAvailability []HardwareAvailability `json:"hardware_availability"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// HardwareAvailability represents availability for specific hardware at a location
|
||||
type HardwareAvailability struct {
|
||||
HardwareID int `json:"hardware_id"`
|
||||
HardwareName string `json:"hardware_name"`
|
||||
AvailableCount int `json:"available_count"`
|
||||
MaxGPUs int `json:"max_gpus"`
|
||||
}
|
||||
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -19,6 +19,22 @@ import (
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
IsSyncImageModel bool
|
||||
}
|
||||
|
||||
var syncModels = []string{
|
||||
"z-image",
|
||||
"qwen-image",
|
||||
"wan2.6",
|
||||
}
|
||||
|
||||
func isSyncImageModel(modelName string) bool {
|
||||
for _, m := range syncModels {
|
||||
if strings.Contains(modelName, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
@@ -45,10 +61,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
}
|
||||
case constant.RelayModeImagesEdits:
|
||||
if isWanModel(info.OriginModelName) {
|
||||
if isOldWanModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
|
||||
} else if isWanModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
}
|
||||
@@ -72,7 +94,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
|
||||
} else {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
@@ -108,15 +134,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
a.IsSyncImageModel = true
|
||||
}
|
||||
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
if isOldWanModel(info.OriginModelName) {
|
||||
return oaiFormEdit2WanxImageEdit(c, info, request)
|
||||
}
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
a.IsSyncImageModel = false
|
||||
} else {
|
||||
a.IsSyncImageModel = true
|
||||
}
|
||||
}
|
||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||
// 如果用户使用表单,则需要解析表单数据
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
@@ -126,9 +162,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
}
|
||||
@@ -150,7 +186,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
// TODO implement me
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -169,13 +205,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
default:
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
err, usage = aliImageHandler(a, c, resp, info)
|
||||
case constant.RelayModeImagesEdits:
|
||||
if isWanModel(info.OriginModelName) {
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
}
|
||||
err, usage = aliImageHandler(a, c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package ali
|
||||
|
||||
import "github.com/QuantumNous/new-api/dto"
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AliMessage struct {
|
||||
Content any `json:"content"`
|
||||
@@ -65,6 +72,7 @@ type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
type TaskResult struct {
|
||||
@@ -75,14 +83,78 @@ type TaskResult struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []map[string]any `json:"choices,omitempty"`
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []AliMediaContent `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
} `json:"message,omitempty"`
|
||||
} `json:"choices,omitempty"`
|
||||
}
|
||||
|
||||
func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||
var imageData []dto.ImageData
|
||||
if len(o.Choices) > 0 {
|
||||
for _, choice := range o.Choices {
|
||||
var data dto.ImageData
|
||||
for _, content := range choice.Message.Content {
|
||||
if content.Image != "" {
|
||||
if strings.HasPrefix(content.Image, "http") {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(content.Image)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
}
|
||||
data.Url = content.Image
|
||||
data.B64Json = b64Json
|
||||
} else {
|
||||
data.B64Json = content.Image
|
||||
}
|
||||
} else if content.Text != "" {
|
||||
data.RevisedPrompt = content.Text
|
||||
}
|
||||
}
|
||||
imageData = append(imageData, data)
|
||||
}
|
||||
}
|
||||
|
||||
return imageData
|
||||
}
|
||||
|
||||
func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||
var imageData []dto.ImageData
|
||||
for _, data := range o.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
} else {
|
||||
b64Json = data.B64Image
|
||||
}
|
||||
|
||||
imageData = append(imageData, dto.ImageData{
|
||||
Url: data.Url,
|
||||
B64Json: b64Json,
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
}
|
||||
return imageData
|
||||
}
|
||||
|
||||
type AliResponse struct {
|
||||
@@ -92,18 +164,26 @@ type AliResponse struct {
|
||||
}
|
||||
|
||||
type AliImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters AliImageParameters `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type AliImageParameters struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Steps string `json:"steps,omitempty"`
|
||||
Scale string `json:"scale,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
PromptExtend *bool `json:"prompt_extend,omitempty"`
|
||||
}
|
||||
|
||||
func (p *AliImageParameters) PromptExtendValue() bool {
|
||||
if p != nil && p.PromptExtend != nil {
|
||||
return *p.PromptExtend
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AliImageInput struct {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -21,17 +20,23 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||
if request.Extra != nil {
|
||||
if val, ok := request.Extra["parameters"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 兼容没有parameters字段的情况,从openai标准字段中提取参数
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
}
|
||||
if val, ok := request.Extra["input"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Input)
|
||||
@@ -41,23 +46,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Parameters == nil {
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
if strings.Contains(request.Model, "z-image") {
|
||||
// z-image 开启prompt_extend后,按2倍计费
|
||||
if imageRequest.Parameters.PromptExtendValue() {
|
||||
info.PriceData.AddOtherRatio("prompt_extend", 2)
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
// 检查n参数
|
||||
if imageRequest.Parameters.N != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
}
|
||||
|
||||
// 同步图片模型和异步图片模型请求格式不一样
|
||||
if isSync {
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Messages: []AliMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []AliMediaContent{
|
||||
{
|
||||
Text: request.Prompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
@@ -199,6 +225,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
|
||||
var taskResponse AliResponse
|
||||
var responseBody []byte
|
||||
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
|
||||
for {
|
||||
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
||||
step++
|
||||
@@ -238,32 +266,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
|
||||
Created: info.StartTime.Unix(),
|
||||
}
|
||||
|
||||
for _, data := range response.Output.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
} else {
|
||||
b64Json = data.B64Image
|
||||
}
|
||||
|
||||
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
|
||||
Url: data.Url,
|
||||
B64Json: b64Json,
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
if len(response.Output.Results) > 0 {
|
||||
imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
|
||||
} else if len(response.Output.Choices) > 0 {
|
||||
imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
|
||||
}
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(originBody, &mapResponse)
|
||||
imageResponse.Extra = mapResponse
|
||||
|
||||
imageResponse.Metadata = originBody
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse AliResponse
|
||||
@@ -282,66 +295,49 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var (
|
||||
aliResponse *AliResponse
|
||||
originRespBody []byte
|
||||
)
|
||||
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Message != "" {
|
||||
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
||||
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var fullTextResponse dto.ImageResponse
|
||||
if len(aliResponse.Output.Choices) > 0 {
|
||||
fullTextResponse = dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
Data: []dto.ImageData{
|
||||
{
|
||||
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
||||
B64Json: "",
|
||||
},
|
||||
},
|
||||
if a.IsSyncImageModel {
|
||||
aliResponse = &aliTaskResponse
|
||||
originRespBody = responseBody
|
||||
} else {
|
||||
// 异步图片模型需要轮询任务结果
|
||||
aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
}
|
||||
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(responseBody, &mapResponse)
|
||||
fullTextResponse.Extra = mapResponse
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
//logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
|
||||
if a.IsSyncImageModel {
|
||||
logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
|
||||
} else {
|
||||
logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
|
||||
}
|
||||
|
||||
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
// 可能生成多张图片,修正计费数量n
|
||||
if aliResponse.Usage.ImageCount != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
|
||||
} else if len(imageResponses.Data) != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
|
||||
}
|
||||
jsonResponse, err := common.Marshal(imageResponses)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
|
||||
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
|
||||
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
|
||||
}
|
||||
wanParams := WanImageParameters{
|
||||
//wanParams := WanImageParameters{
|
||||
// N: int(request.N),
|
||||
//}
|
||||
imageRequest.Input = wanInput
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
N: int(request.N),
|
||||
}
|
||||
imageRequest.Input = wanInput
|
||||
imageRequest.Parameters = wanParams
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func isOldWanModel(modelName string) bool {
|
||||
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
|
||||
}
|
||||
|
||||
func isWanModel(modelName string) bool {
|
||||
return strings.Contains(modelName, "wan")
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -37,6 +39,13 @@ func getAwsErrorStatusCode(err error) int {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func newAwsInvokeContext() (context.Context, context.CancelFunc) {
|
||||
if common.RelayTimeout <= 0 {
|
||||
return context.Background(), func() {}
|
||||
}
|
||||
return context.WithTimeout(context.Background(), time.Duration(common.RelayTimeout)*time.Second)
|
||||
}
|
||||
|
||||
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||
var (
|
||||
httpClient *http.Client
|
||||
@@ -117,6 +126,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
a.AwsReq = awsReq
|
||||
return nil, nil
|
||||
} else {
|
||||
awsClaudeReq, err := formatRequest(requestBody, requestHeader)
|
||||
@@ -201,7 +211,10 @@ func getAwsModelID(requestModel string) string {
|
||||
|
||||
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
ctx, cancel := newAwsInvokeContext()
|
||||
defer cancel()
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
statusCode := getAwsErrorStatusCode(err)
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
||||
@@ -228,7 +241,10 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
|
||||
}
|
||||
|
||||
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||
ctx, cancel := newAwsInvokeContext()
|
||||
defer cancel()
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(ctx, a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
|
||||
if err != nil {
|
||||
statusCode := getAwsErrorStatusCode(err)
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
||||
@@ -268,7 +284,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
ctx, cancel := newAwsInvokeContext()
|
||||
defer cancel()
|
||||
|
||||
awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
|
||||
if err != nil {
|
||||
statusCode := getAwsErrorStatusCode(err)
|
||||
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
|
||||
|
||||
@@ -483,9 +483,11 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
}
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
|
||||
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
}
|
||||
//claudeUsage = &claudeResponse.Usage
|
||||
} else if claudeResponse.Type == "message_stop" {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -137,7 +138,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
} else if baseModel, level := parseThinkingLevelSuffix(info.UpstreamModelName); level != "" {
|
||||
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
|
||||
info.UpstreamModelName = baseModel
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -32,6 +34,7 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"audio/wav": true,
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true, // support old image/jpeg
|
||||
"image/webp": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
@@ -98,6 +101,7 @@ func clampThinkingBudget(modelName string, budget int) int {
|
||||
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
|
||||
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
|
||||
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
|
||||
// "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
|
||||
func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
isNew25Pro := isNew25ProModel(modelName)
|
||||
is25FlashLite := is25FlashLiteModel(modelName)
|
||||
@@ -118,18 +122,12 @@ func clampThinkingBudgetByEffort(modelName string, effort string) int {
|
||||
maxBudget = maxBudget * 50 / 100
|
||||
case "low":
|
||||
maxBudget = maxBudget * 20 / 100
|
||||
case "minimal":
|
||||
maxBudget = maxBudget * 5 / 100
|
||||
}
|
||||
return clampThinkingBudget(modelName, maxBudget)
|
||||
}
|
||||
|
||||
func parseThinkingLevelSuffix(modelName string) (string, string) {
|
||||
base, level, ok := reasoning.TrimEffortSuffix(modelName)
|
||||
if !ok {
|
||||
return modelName, ""
|
||||
}
|
||||
return base, level
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
@@ -186,7 +184,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
} else if _, level := parseThinkingLevelSuffix(modelName); level != "" {
|
||||
} else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingLevel: level,
|
||||
@@ -379,7 +377,7 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
var system_content []string
|
||||
//shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
if message.Role == "system" {
|
||||
if message.Role == "system" || message.Role == "developer" {
|
||||
system_content = append(system_content, message.StringContent())
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
@@ -677,6 +675,7 @@ func cleanFunctionParameters(params interface{}) interface{} {
|
||||
delete(cleanedMap, "exclusiveMinimum")
|
||||
delete(cleanedMap, "$schema")
|
||||
delete(cleanedMap, "additionalProperties")
|
||||
delete(cleanedMap, "propertyNames")
|
||||
|
||||
// Check and clean 'format' for string types
|
||||
if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
|
||||
@@ -1367,3 +1366,76 @@ func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
type GeminiModelsResponse struct {
|
||||
Models []dto.GeminiModel `json:"models"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
func FetchGeminiModels(baseURL, apiKey, proxyURL string) ([]string, error) {
|
||||
client, err := service.GetHttpClientWithProxy(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP客户端失败: %v", err)
|
||||
}
|
||||
|
||||
allModels := make([]string, 0)
|
||||
nextPageToken := ""
|
||||
maxPages := 100 // Safety limit to prevent infinite loops
|
||||
|
||||
for page := 0; page < maxPages; page++ {
|
||||
url := fmt.Sprintf("%s/v1beta/models", baseURL)
|
||||
if nextPageToken != "" {
|
||||
url = fmt.Sprintf("%s?pageToken=%s", url, nextPageToken)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
request, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
response.Body.Close()
|
||||
cancel()
|
||||
return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
response.Body.Close()
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
var modelsResponse GeminiModelsResponse
|
||||
if err = common.Unmarshal(body, &modelsResponse); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
for _, model := range modelsResponse.Models {
|
||||
modelNameValue, ok := model.Name.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
modelName := strings.TrimPrefix(modelNameValue, "models/")
|
||||
allModels = append(allModels, modelName)
|
||||
}
|
||||
|
||||
nextPageToken = modelsResponse.NextPageToken
|
||||
if nextPageToken == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return allModels, nil
|
||||
}
|
||||
|
||||
@@ -14,6 +14,9 @@ var ModelList = []string{
|
||||
"speech-02-turbo",
|
||||
"speech-01-hd",
|
||||
"speech-01-turbo",
|
||||
"MiniMax-M2.1",
|
||||
"MiniMax-M2.1-lightning",
|
||||
"MiniMax-M2",
|
||||
}
|
||||
|
||||
var ChannelName = "minimax"
|
||||
|
||||
@@ -67,3 +67,40 @@ type OllamaEmbeddingResponse struct {
|
||||
Embeddings [][]float64 `json:"embeddings"`
|
||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaTagsResponse struct {
|
||||
Models []OllamaModel `json:"models"`
|
||||
}
|
||||
|
||||
type OllamaModel struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Details OllamaModelDetail `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaModelDetail struct {
|
||||
ParentModel string `json:"parent_model,omitempty"`
|
||||
Format string `json:"format,omitempty"`
|
||||
Family string `json:"family,omitempty"`
|
||||
Families []string `json:"families,omitempty"`
|
||||
ParameterSize string `json:"parameter_size,omitempty"`
|
||||
QuantizationLevel string `json:"quantization_level,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaPullRequest struct {
|
||||
Name string `json:"name"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaPullResponse struct {
|
||||
Status string `json:"status"`
|
||||
Digest string `json:"digest,omitempty"`
|
||||
Total int64 `json:"total,omitempty"`
|
||||
Completed int64 `json:"completed,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaDeleteRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
@@ -283,3 +285,246 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
|
||||
service.IOCopyBytesGracefully(c, resp, out)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func FetchOllamaModels(baseURL, apiKey string) ([]OllamaModel, error) {
|
||||
url := fmt.Sprintf("%s/api/tags", baseURL)
|
||||
|
||||
client := &http.Client{}
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
// Ollama 通常不需要 Bearer token,但为了兼容性保留
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tagsResponse OllamaTagsResponse
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
err = common.Unmarshal(body, &tagsResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
return tagsResponse.Models, nil
|
||||
}
|
||||
|
||||
// 拉取 Ollama 模型 (非流式)
|
||||
func PullOllamaModel(baseURL, apiKey, modelName string) error {
|
||||
url := fmt.Sprintf("%s/api/pull", baseURL)
|
||||
|
||||
pullRequest := OllamaPullRequest{
|
||||
Name: modelName,
|
||||
Stream: false, // 非流式,简化处理
|
||||
}
|
||||
|
||||
requestBody, err := common.Marshal(pullRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 30 * 60 * 1000 * time.Millisecond, // 30分钟超时,支持大模型
|
||||
}
|
||||
request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 流式拉取 Ollama 模型 (支持进度回调)
|
||||
func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback func(OllamaPullResponse)) error {
|
||||
url := fmt.Sprintf("%s/api/pull", baseURL)
|
||||
|
||||
pullRequest := OllamaPullRequest{
|
||||
Name: modelName,
|
||||
Stream: true, // 启用流式
|
||||
}
|
||||
|
||||
requestBody, err := common.Marshal(pullRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 60 * 60 * 1000 * time.Millisecond, // 1小时超时,支持超大模型
|
||||
}
|
||||
request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 读取流式响应
|
||||
scanner := bufio.NewScanner(response.Body)
|
||||
successful := false
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var pullResponse OllamaPullResponse
|
||||
if err := common.Unmarshal([]byte(line), &pullResponse); err != nil {
|
||||
continue // 忽略解析失败的行
|
||||
}
|
||||
|
||||
if progressCallback != nil {
|
||||
progressCallback(pullResponse)
|
||||
}
|
||||
|
||||
// 检查是否出现错误或完成
|
||||
if strings.EqualFold(pullResponse.Status, "error") {
|
||||
return fmt.Errorf("拉取模型失败: %s", strings.TrimSpace(line))
|
||||
}
|
||||
if strings.EqualFold(pullResponse.Status, "success") {
|
||||
successful = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取流式响应失败: %v", err)
|
||||
}
|
||||
|
||||
if !successful {
|
||||
return fmt.Errorf("拉取模型未完成: 未收到成功状态")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除 Ollama 模型
|
||||
func DeleteOllamaModel(baseURL, apiKey, modelName string) error {
|
||||
url := fmt.Sprintf("%s/api/delete", baseURL)
|
||||
|
||||
deleteRequest := OllamaDeleteRequest{
|
||||
Name: modelName,
|
||||
}
|
||||
|
||||
requestBody, err := common.Marshal(deleteRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化请求失败: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
request, err := http.NewRequest("DELETE", url, strings.NewReader(string(requestBody)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return fmt.Errorf("删除模型失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchOllamaVersion(baseURL, apiKey string) (string, error) {
|
||||
trimmedBase := strings.TrimRight(baseURL, "/")
|
||||
if trimmedBase == "" {
|
||||
return "", fmt.Errorf("baseURL 为空")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/version", trimmedBase)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建请求失败: %v", err)
|
||||
}
|
||||
|
||||
if apiKey != "" {
|
||||
request.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("请求失败: %v", err)
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取响应失败: %v", err)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("查询版本失败 %d: %s", response.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var versionResp struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &versionResp); err != nil {
|
||||
return "", fmt.Errorf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if versionResp.Version == "" {
|
||||
return "", fmt.Errorf("未返回版本信息")
|
||||
}
|
||||
|
||||
return versionResp.Version, nil
|
||||
}
|
||||
|
||||
234
relay/channel/openai/chat_via_responses.go
Normal file
234
relay/channel/openai/chat_via_responses.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
var responsesResp dto.OpenAIResponsesResponse
|
||||
const maxResponseBodyBytes = 10 << 20 // 10MB
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes+1))
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
|
||||
}
|
||||
if int64(len(body)) > maxResponseBodyBytes {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("response body exceeds %d bytes", maxResponseBodyBytes), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &responsesResp); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if oaiError := responsesResp.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
|
||||
chatId := helper.GetResponseID(c)
|
||||
chatResp, usage, err := service.ResponsesResponseToChatCompletionsResponse(&responsesResp, chatId)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if usage == nil || usage.TotalTokens == 0 {
|
||||
text := service.ExtractOutputTextFromResponses(&responsesResp)
|
||||
usage = service.ResponseText2Usage(c, text, info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
chatResp.Usage = *usage
|
||||
}
|
||||
|
||||
chatBody, err := common.Marshal(chatResp)
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeJsonMarshalFailed, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, resp, chatBody)
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
if resp == nil || resp.Body == nil {
|
||||
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
responseId := helper.GetResponseID(c)
|
||||
createAt := time.Now().Unix()
|
||||
model := info.UpstreamModelName
|
||||
|
||||
var (
|
||||
usage = &dto.Usage{}
|
||||
textBuilder strings.Builder
|
||||
sentStart bool
|
||||
sentStop bool
|
||||
streamErr *types.NewAPIError
|
||||
)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
if streamErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var streamResp dto.ResponsesStreamResponse
|
||||
if err := common.UnmarshalJsonStr(data, &streamResp); err != nil {
|
||||
logger.LogError(c, "failed to unmarshal responses stream event: "+err.Error())
|
||||
return true
|
||||
}
|
||||
|
||||
switch streamResp.Type {
|
||||
case "response.created":
|
||||
if streamResp.Response != nil {
|
||||
if streamResp.Response.Model != "" {
|
||||
model = streamResp.Response.Model
|
||||
}
|
||||
if streamResp.Response.CreatedAt != 0 {
|
||||
createAt = int64(streamResp.Response.CreatedAt)
|
||||
}
|
||||
}
|
||||
|
||||
case "response.output_text.delta":
|
||||
if !sentStart {
|
||||
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
sentStart = true
|
||||
}
|
||||
|
||||
if streamResp.Delta != "" {
|
||||
textBuilder.WriteString(streamResp.Delta)
|
||||
delta := streamResp.Delta
|
||||
chunk := &dto.ChatCompletionsStreamResponse{
|
||||
Id: responseId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: createAt,
|
||||
Model: model,
|
||||
Choices: []dto.ChatCompletionsStreamResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Content: &delta,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := helper.ObjectData(c, chunk); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
case "response.completed":
|
||||
if streamResp.Response != nil {
|
||||
if streamResp.Response.Model != "" {
|
||||
model = streamResp.Response.Model
|
||||
}
|
||||
if streamResp.Response.CreatedAt != 0 {
|
||||
createAt = int64(streamResp.Response.CreatedAt)
|
||||
}
|
||||
if streamResp.Response.Usage != nil {
|
||||
if streamResp.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResp.Response.Usage.InputTokens
|
||||
usage.InputTokens = streamResp.Response.Usage.InputTokens
|
||||
}
|
||||
if streamResp.Response.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = streamResp.Response.Usage.OutputTokens
|
||||
usage.OutputTokens = streamResp.Response.Usage.OutputTokens
|
||||
}
|
||||
if streamResp.Response.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = streamResp.Response.Usage.TotalTokens
|
||||
} else {
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
if streamResp.Response.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = streamResp.Response.Usage.InputTokensDetails.CachedTokens
|
||||
usage.PromptTokensDetails.ImageTokens = streamResp.Response.Usage.InputTokensDetails.ImageTokens
|
||||
usage.PromptTokensDetails.AudioTokens = streamResp.Response.Usage.InputTokensDetails.AudioTokens
|
||||
}
|
||||
if streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens != 0 {
|
||||
usage.CompletionTokenDetails.ReasoningTokens = streamResp.Response.Usage.CompletionTokenDetails.ReasoningTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !sentStart {
|
||||
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
sentStart = true
|
||||
}
|
||||
if !sentStop {
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, "stop")
|
||||
if err := helper.ObjectData(c, stop); err != nil {
|
||||
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
sentStop = true
|
||||
}
|
||||
|
||||
case "response.error", "response.failed":
|
||||
if streamResp.Response != nil {
|
||||
if oaiErr := streamResp.Response.GetOpenAIError(); oaiErr != nil && oaiErr.Type != "" {
|
||||
streamErr = types.WithOpenAIError(*oaiErr, http.StatusInternalServerError)
|
||||
return false
|
||||
}
|
||||
}
|
||||
streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
return false
|
||||
|
||||
case "response.output_item.added", "response.output_item.done":
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if streamErr != nil {
|
||||
return nil, streamErr
|
||||
}
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(c, textBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
|
||||
}
|
||||
|
||||
if !sentStart {
|
||||
if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
if !sentStop {
|
||||
stop := helper.GenerateStopResponse(responseId, createAt, model, "stop")
|
||||
if err := helper.ObjectData(c, stop); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
if info.ShouldIncludeUsage && usage != nil {
|
||||
if err := helper.ObjectData(c, helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)); err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
helper.Done(c)
|
||||
return usage, nil
|
||||
}
|
||||
@@ -208,7 +208,6 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
helper.Done(c)
|
||||
|
||||
case types.RelayFormatClaude:
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
|
||||
common.SysLog("error unmarshalling stream response: " + err.Error())
|
||||
@@ -221,6 +220,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
for _, resp := range claudeResponses {
|
||||
_ = helper.ClaudeData(c, *resp)
|
||||
}
|
||||
info.ClaudeConvertInfo.Done = true
|
||||
|
||||
case types.RelayFormatGemini:
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
|
||||
@@ -186,7 +186,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
applyUsagePostProcessing(info, usage, nil)
|
||||
applyUsagePostProcessing(info, usage, common.StringToByteSlice(lastStreamData))
|
||||
|
||||
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
|
||||
|
||||
@@ -597,6 +597,7 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
@@ -606,6 +607,19 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
case constant.ChannelTypeMoonshot:
|
||||
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
|
||||
if usage.PromptTokensDetails.CachedTokens == 0 {
|
||||
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
|
||||
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
|
||||
usage.PromptTokensDetails.CachedTokens = cachedTokens
|
||||
} else if usage.PromptCacheHitTokens > 0 {
|
||||
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -639,3 +653,32 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
|
||||
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
|
||||
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Usage struct {
|
||||
CachedTokens *int `json:"cached_tokens"`
|
||||
} `json:"usage"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(body, &payload); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// 遍历choices查找cached_tokens
|
||||
for _, choice := range payload.Choices {
|
||||
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
|
||||
return *choice.Usage.CachedTokens, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
@@ -192,6 +192,10 @@ func sizeToResolution(size string) (string, error) {
|
||||
func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) {
|
||||
otherRatios := make(map[string]float64)
|
||||
aliRatios := map[string]map[string]float64{
|
||||
"wan2.6-i2v": {
|
||||
"720P": 1,
|
||||
"1080P": 1 / 0.6,
|
||||
},
|
||||
"wan2.5-t2v-preview": {
|
||||
"480P": 1,
|
||||
"720P": 2,
|
||||
@@ -287,7 +291,9 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
aliReq.Parameters.Size = "1280*720"
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(req.Model, "wan2.5") {
|
||||
if strings.HasPrefix(req.Model, "wan2.6") {
|
||||
aliReq.Parameters.Resolution = "1080P"
|
||||
} else if strings.HasPrefix(req.Model, "wan2.5") {
|
||||
aliReq.Parameters.Resolution = "1080P"
|
||||
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
|
||||
aliReq.Parameters.Resolution = "720P"
|
||||
|
||||
@@ -346,7 +346,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
taskInfo.Code = resPayload.Code
|
||||
taskInfo.TaskID = resPayload.Data.TaskId
|
||||
taskInfo.Reason = resPayload.Message
|
||||
taskInfo.Reason = resPayload.Data.TaskStatusMsg
|
||||
//任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
|
||||
status := resPayload.Data.TaskStatus
|
||||
switch status {
|
||||
|
||||
@@ -40,6 +40,7 @@ var claudeModelMap = map[string]string{
|
||||
"claude-opus-4-20250514": "claude-opus-4@20250514",
|
||||
"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
|
||||
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
|
||||
"claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001",
|
||||
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
|
||||
}
|
||||
|
||||
|
||||
@@ -270,6 +270,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
|
||||
case constant.RelayModeRerank:
|
||||
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
|
||||
case constant.RelayModeResponses:
|
||||
return fmt.Sprintf("%s/api/v3/responses", baseUrl), nil
|
||||
case constant.RelayModeAudioSpeech:
|
||||
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
|
||||
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
|
||||
@@ -323,7 +325,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
|
||||
160
relay/chat_completions_via_responses.go
Normal file
160
relay/chat_completions_via_responses.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
openaichannel "github.com/QuantumNous/new-api/relay/channel/openai"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func applySystemPromptIfNeeded(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) {
|
||||
if info == nil || request == nil {
|
||||
return
|
||||
}
|
||||
if info.ChannelSetting.SystemPrompt == "" {
|
||||
return
|
||||
}
|
||||
|
||||
systemRole := request.GetSystemRoleName()
|
||||
|
||||
containSystemPrompt := false
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == systemRole {
|
||||
containSystemPrompt = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !containSystemPrompt {
|
||||
systemMessage := dto.Message{
|
||||
Role: systemRole,
|
||||
Content: info.ChannelSetting.SystemPrompt,
|
||||
}
|
||||
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
|
||||
return
|
||||
}
|
||||
|
||||
if !info.ChannelSetting.SystemPromptOverride {
|
||||
return
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
|
||||
for i, message := range request.Messages {
|
||||
if message.Role != systemRole {
|
||||
continue
|
||||
}
|
||||
if message.IsStringContent() {
|
||||
request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
|
||||
return
|
||||
}
|
||||
contents := message.ParseContent()
|
||||
contents = append([]dto.MediaContent{
|
||||
{
|
||||
Type: dto.ContentTypeText,
|
||||
Text: info.ChannelSetting.SystemPrompt,
|
||||
},
|
||||
}, contents...)
|
||||
request.Messages[i].Content = contents
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func chatCompletionsViaResponses(c *gin.Context, info *relaycommon.RelayInfo, adaptor channel.Adaptor, request *dto.GeneralOpenAIRequest) (*dto.Usage, *types.NewAPIError) {
|
||||
overrideCtx := relaycommon.BuildParamOverrideContext(info)
|
||||
chatJSON, err := common.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
chatJSON, err = relaycommon.RemoveDisabledFields(chatJSON, info.ChannelOtherSettings)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if len(info.ParamOverride) > 0 {
|
||||
chatJSON, err = relaycommon.ApplyParamOverride(chatJSON, info.ParamOverride, overrideCtx)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
var overriddenChatReq dto.GeneralOpenAIRequest
|
||||
if err := common.Unmarshal(chatJSON, &overriddenChatReq); err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
responsesReq, err := service.ChatCompletionsRequestToResponsesRequest(&overriddenChatReq)
|
||||
if err != nil {
|
||||
return nil, types.NewErrorWithStatusCode(err, types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
savedRelayMode := info.RelayMode
|
||||
savedRequestURLPath := info.RequestURLPath
|
||||
defer func() {
|
||||
info.RelayMode = savedRelayMode
|
||||
info.RequestURLPath = savedRequestURLPath
|
||||
}()
|
||||
|
||||
info.RelayMode = relayconstant.RelayModeResponses
|
||||
info.RequestURLPath = "/v1/responses"
|
||||
|
||||
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, info, *responsesReq)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
jsonData, err := common.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, info, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, types.NewOpenAIError(nil, types.ErrorCodeBadResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return nil, newApiErr
|
||||
}
|
||||
|
||||
if info.IsStream {
|
||||
usage, newApiErr := openaichannel.OaiResponsesToChatStreamHandler(c, info, httpResp)
|
||||
if newApiErr != nil {
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return nil, newApiErr
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
usage, newApiErr := openaichannel.OaiResponsesToChatHandler(c, info, httpResp)
|
||||
if newApiErr != nil {
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return nil, newApiErr
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
@@ -23,7 +23,7 @@ type ConditionOperation struct {
|
||||
|
||||
type ParamOperation struct {
|
||||
Path string `json:"path"`
|
||||
Mode string `json:"mode"` // delete, set, move, prepend, append
|
||||
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
|
||||
Value interface{} `json:"value"`
|
||||
KeepOrigin bool `json:"keep_origin"`
|
||||
From string `json:"from,omitempty"`
|
||||
@@ -330,8 +330,6 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
}
|
||||
// 处理路径中的负数索引
|
||||
opPath := processNegativeIndex(result, op.Path)
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
|
||||
switch op.Mode {
|
||||
case "delete":
|
||||
@@ -342,11 +340,38 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
|
||||
}
|
||||
result, err = sjson.Set(result, opPath, op.Value)
|
||||
case "move":
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
result, err = moveValue(result, opFrom, opTo)
|
||||
case "copy":
|
||||
if op.From == "" || op.To == "" {
|
||||
return "", fmt.Errorf("copy from/to is required")
|
||||
}
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
result, err = copyValue(result, opFrom, opTo)
|
||||
case "prepend":
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
|
||||
case "append":
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
|
||||
case "trim_prefix":
|
||||
result, err = trimStringValue(result, opPath, op.Value, true)
|
||||
case "trim_suffix":
|
||||
result, err = trimStringValue(result, opPath, op.Value, false)
|
||||
case "ensure_prefix":
|
||||
result, err = ensureStringAffix(result, opPath, op.Value, true)
|
||||
case "ensure_suffix":
|
||||
result, err = ensureStringAffix(result, opPath, op.Value, false)
|
||||
case "trim_space":
|
||||
result, err = transformStringValue(result, opPath, strings.TrimSpace)
|
||||
case "to_lower":
|
||||
result, err = transformStringValue(result, opPath, strings.ToLower)
|
||||
case "to_upper":
|
||||
result, err = transformStringValue(result, opPath, strings.ToUpper)
|
||||
case "replace":
|
||||
result, err = replaceStringValue(result, opPath, op.From, op.To)
|
||||
case "regex_replace":
|
||||
result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||
}
|
||||
@@ -369,6 +394,14 @@ func moveValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
return sjson.Delete(result, fromPath)
|
||||
}
|
||||
|
||||
func copyValue(jsonStr, fromPath, toPath string) (string, error) {
|
||||
sourceValue := gjson.Get(jsonStr, fromPath)
|
||||
if !sourceValue.Exists() {
|
||||
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
|
||||
}
|
||||
return sjson.Set(jsonStr, toPath, sourceValue.Value())
|
||||
}
|
||||
|
||||
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
switch {
|
||||
@@ -422,6 +455,88 @@ func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (stri
|
||||
return sjson.Set(jsonStr, path, newStr)
|
||||
}
|
||||
|
||||
func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
return jsonStr, fmt.Errorf("trim value is required")
|
||||
}
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
|
||||
var newStr string
|
||||
if isPrefix {
|
||||
newStr = strings.TrimPrefix(current.String(), valueStr)
|
||||
} else {
|
||||
newStr = strings.TrimSuffix(current.String(), valueStr)
|
||||
}
|
||||
return sjson.Set(jsonStr, path, newStr)
|
||||
}
|
||||
|
||||
func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
return jsonStr, fmt.Errorf("ensure value is required")
|
||||
}
|
||||
valueStr := fmt.Sprintf("%v", value)
|
||||
if valueStr == "" {
|
||||
return jsonStr, fmt.Errorf("ensure value is required")
|
||||
}
|
||||
|
||||
currentStr := current.String()
|
||||
if isPrefix {
|
||||
if strings.HasPrefix(currentStr, valueStr) {
|
||||
return jsonStr, nil
|
||||
}
|
||||
return sjson.Set(jsonStr, path, valueStr+currentStr)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(currentStr, valueStr) {
|
||||
return jsonStr, nil
|
||||
}
|
||||
return sjson.Set(jsonStr, path, currentStr+valueStr)
|
||||
}
|
||||
|
||||
func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
return sjson.Set(jsonStr, path, transform(current.String()))
|
||||
}
|
||||
|
||||
func replaceStringValue(jsonStr, path, from, to string) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
if from == "" {
|
||||
return jsonStr, fmt.Errorf("replace from is required")
|
||||
}
|
||||
return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to))
|
||||
}
|
||||
|
||||
func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
if current.Type != gjson.String {
|
||||
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
|
||||
}
|
||||
if pattern == "" {
|
||||
return jsonStr, fmt.Errorf("regex pattern is required")
|
||||
}
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return jsonStr, err
|
||||
}
|
||||
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
|
||||
}
|
||||
|
||||
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
|
||||
current := gjson.Get(jsonStr, path)
|
||||
var currentMap, newMap map[string]interface{}
|
||||
@@ -455,18 +570,19 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
|
||||
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
|
||||
// 目前内置以下字段:
|
||||
// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
|
||||
// - upstream_model:始终为通道映射后的上游模型名。
|
||||
// - upstream_model/model:始终为通道映射后的上游模型名。
|
||||
// - original_model:请求最初指定的模型名。
|
||||
// - request_path:请求路径
|
||||
// - is_channel_test:是否为渠道测试请求(同 is_test)。
|
||||
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||
if info == nil || info.ChannelMeta == nil {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := make(map[string]interface{})
|
||||
if info.UpstreamModelName != "" {
|
||||
ctx["model"] = info.UpstreamModelName
|
||||
ctx["upstream_model"] = info.UpstreamModelName
|
||||
if info.ChannelMeta != nil && info.ChannelMeta.UpstreamModelName != "" {
|
||||
ctx["model"] = info.ChannelMeta.UpstreamModelName
|
||||
ctx["upstream_model"] = info.ChannelMeta.UpstreamModelName
|
||||
}
|
||||
if info.OriginModelName != "" {
|
||||
ctx["original_model"] = info.OriginModelName
|
||||
@@ -475,8 +591,13 @@ func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
if len(ctx) == 0 {
|
||||
return nil
|
||||
if info.RequestURLPath != "" {
|
||||
requestPath := info.RequestURLPath
|
||||
if requestPath != "" {
|
||||
ctx["request_path"] = requestPath
|
||||
}
|
||||
}
|
||||
|
||||
ctx["is_channel_test"] = info.IsChannelTest
|
||||
return ctx
|
||||
}
|
||||
|
||||
791
relay/common/override_test.go
Normal file
791
relay/common/override_test.go
Normal file
@@ -0,0 +1,791 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
|
||||
// trim_prefix example:
|
||||
// {"operations":[{"path":"model","mode":"trim_prefix","value":"openai/"}]}
|
||||
input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "trim_prefix",
|
||||
"value": "openai/",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideTrimSuffix(t *testing.T) {
|
||||
// trim_suffix example:
|
||||
// {"operations":[{"path":"model","mode":"trim_suffix","value":"-latest"}]}
|
||||
input := []byte(`{"model":"gpt-4-latest","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "trim_suffix",
|
||||
"value": "-latest",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideTrimNoop(t *testing.T) {
|
||||
// trim_prefix no-op example:
|
||||
// {"operations":[{"path":"model","mode":"trim_prefix","value":"openai/"}]}
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "trim_prefix",
|
||||
"value": "openai/",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideTrimRequiresValue(t *testing.T) {
|
||||
// trim_prefix requires value example:
|
||||
// {"operations":[{"path":"model","mode":"trim_prefix"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "trim_prefix",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideReplace(t *testing.T) {
|
||||
// replace example:
|
||||
// {"operations":[{"path":"model","mode":"replace","from":"openai/","to":""}]}
|
||||
input := []byte(`{"model":"openai/gpt-4o-mini","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "replace",
|
||||
"from": "openai/",
|
||||
"to": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4o-mini","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideRegexReplace(t *testing.T) {
|
||||
// regex_replace example:
|
||||
// {"operations":[{"path":"model","mode":"regex_replace","from":"^gpt-","to":"openai/gpt-"}]}
|
||||
input := []byte(`{"model":"gpt-4o-mini","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "regex_replace",
|
||||
"from": "^gpt-",
|
||||
"to": "openai/gpt-",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"openai/gpt-4o-mini","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideReplaceRequiresFrom(t *testing.T) {
|
||||
// replace requires from example:
|
||||
// {"operations":[{"path":"model","mode":"replace"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "replace",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideRegexReplaceRequiresPattern(t *testing.T) {
|
||||
// regex_replace requires from(pattern) example:
|
||||
// {"operations":[{"path":"model","mode":"regex_replace"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "regex_replace",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideDelete(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "delete",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]interface{}
|
||||
if err := json.Unmarshal(out, &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal output JSON: %v", err)
|
||||
}
|
||||
if _, exists := got["temperature"]; exists {
|
||||
t.Fatalf("expected temperature to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideSet(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.1}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideSetKeepOrigin(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
"keep_origin": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideMove(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4","meta":{"x":1}}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "move",
|
||||
"from": "model",
|
||||
"to": "meta.model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"meta":{"x":1,"model":"gpt-4"}}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideMoveMissingSource(t *testing.T) {
|
||||
input := []byte(`{"meta":{"x":1}}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "move",
|
||||
"from": "model",
|
||||
"to": "meta.model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverridePrependAppendString(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "prepend",
|
||||
"value": "openai/",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "append",
|
||||
"value": "-latest",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"openai/gpt-4-latest"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverridePrependAppendArray(t *testing.T) {
|
||||
input := []byte(`{"arr":[1,2]}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "arr",
|
||||
"mode": "prepend",
|
||||
"value": 0,
|
||||
},
|
||||
map[string]interface{}{
|
||||
"path": "arr",
|
||||
"mode": "append",
|
||||
"value": []interface{}{3, 4},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"arr":[0,1,2,3,4]}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideAppendObjectMergeKeepOrigin(t *testing.T) {
|
||||
input := []byte(`{"obj":{"a":1}}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "obj",
|
||||
"mode": "append",
|
||||
"keep_origin": true,
|
||||
"value": map[string]interface{}{
|
||||
"a": 2,
|
||||
"b": 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"obj":{"a":1,"b":3}}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideAppendObjectMergeOverride(t *testing.T) {
|
||||
input := []byte(`{"obj":{"a":1}}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "obj",
|
||||
"mode": "append",
|
||||
"value": map[string]interface{}{
|
||||
"a": 2,
|
||||
"b": 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"obj":{"a":2,"b":3}}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideConditionORDefault(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
"conditions": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "prefix",
|
||||
"value": "gpt",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "prefix",
|
||||
"value": "claude",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.1}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideConditionAND(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
"logic": "AND",
|
||||
"conditions": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "prefix",
|
||||
"value": "gpt",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "gt",
|
||||
"value": 0.5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.1}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideConditionInvert(t *testing.T) {
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
"conditions": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "prefix",
|
||||
"value": "gpt",
|
||||
"invert": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideConditionPassMissingKey(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
"conditions": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "prefix",
|
||||
"value": "gpt",
|
||||
"pass_missing_key": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideConditionFromContext(t *testing.T) {
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "temperature",
|
||||
"mode": "set",
|
||||
"value": 0.1,
|
||||
"conditions": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "prefix",
|
||||
"value": "gpt",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := map[string]interface{}{
|
||||
"model": "gpt-4",
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideNegativeIndexPath(t *testing.T) {
|
||||
input := []byte(`{"arr":[{"model":"a"},{"model":"b"}]}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "arr.-1.model",
|
||||
"mode": "set",
|
||||
"value": "c",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"arr":[{"model":"a"},{"model":"c"}]}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideRegexReplaceInvalidPattern(t *testing.T) {
|
||||
// regex_replace invalid pattern example:
|
||||
// {"operations":[{"path":"model","mode":"regex_replace","from":"(","to":"x"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "regex_replace",
|
||||
"from": "(",
|
||||
"to": "x",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideCopy(t *testing.T) {
|
||||
// copy example:
|
||||
// {"operations":[{"mode":"copy","from":"model","to":"original_model"}]}
|
||||
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "copy",
|
||||
"from": "model",
|
||||
"to": "original_model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4","original_model":"gpt-4","temperature":0.7}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideCopyMissingSource(t *testing.T) {
|
||||
// copy missing source example:
|
||||
// {"operations":[{"mode":"copy","from":"model","to":"original_model"}]}
|
||||
input := []byte(`{"temperature":0.7}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "copy",
|
||||
"from": "model",
|
||||
"to": "original_model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideCopyRequiresFromTo(t *testing.T) {
|
||||
// copy requires from/to example:
|
||||
// {"operations":[{"mode":"copy"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"mode": "copy",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideEnsurePrefix(t *testing.T) {
|
||||
// ensure_prefix example:
|
||||
// {"operations":[{"path":"model","mode":"ensure_prefix","value":"openai/"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "ensure_prefix",
|
||||
"value": "openai/",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"openai/gpt-4"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideEnsurePrefixNoop(t *testing.T) {
|
||||
// ensure_prefix no-op example:
|
||||
// {"operations":[{"path":"model","mode":"ensure_prefix","value":"openai/"}]}
|
||||
input := []byte(`{"model":"openai/gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "ensure_prefix",
|
||||
"value": "openai/",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"openai/gpt-4"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideEnsureSuffix(t *testing.T) {
|
||||
// ensure_suffix example:
|
||||
// {"operations":[{"path":"model","mode":"ensure_suffix","value":"-latest"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "ensure_suffix",
|
||||
"value": "-latest",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4-latest"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideEnsureSuffixNoop(t *testing.T) {
|
||||
// ensure_suffix no-op example:
|
||||
// {"operations":[{"path":"model","mode":"ensure_suffix","value":"-latest"}]}
|
||||
input := []byte(`{"model":"gpt-4-latest"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "ensure_suffix",
|
||||
"value": "-latest",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4-latest"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideEnsureRequiresValue(t *testing.T) {
|
||||
// ensure_prefix requires value example:
|
||||
// {"operations":[{"path":"model","mode":"ensure_prefix"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "ensure_prefix",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := ApplyParamOverride(input, override, nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideTrimSpace(t *testing.T) {
|
||||
// trim_space example:
|
||||
// {"operations":[{"path":"model","mode":"trim_space"}]}
|
||||
input := []byte("{\"model\":\" gpt-4 \\n\"}")
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "trim_space",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideToLower(t *testing.T) {
|
||||
// to_lower example:
|
||||
// {"operations":[{"path":"model","mode":"to_lower"}]}
|
||||
input := []byte(`{"model":"GPT-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "to_lower",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"gpt-4"}`, string(out))
|
||||
}
|
||||
|
||||
func TestApplyParamOverrideToUpper(t *testing.T) {
|
||||
// to_upper example:
|
||||
// {"operations":[{"path":"model","mode":"to_upper"}]}
|
||||
input := []byte(`{"model":"gpt-4"}`)
|
||||
override := map[string]interface{}{
|
||||
"operations": []interface{}{
|
||||
map[string]interface{}{
|
||||
"path": "model",
|
||||
"mode": "to_upper",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := ApplyParamOverride(input, override, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ApplyParamOverride returned error: %v", err)
|
||||
}
|
||||
assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
|
||||
}
|
||||
|
||||
func assertJSONEqual(t *testing.T, want, got string) {
|
||||
t.Helper()
|
||||
|
||||
var wantObj interface{}
|
||||
var gotObj interface{}
|
||||
|
||||
if err := json.Unmarshal([]byte(want), &wantObj); err != nil {
|
||||
t.Fatalf("failed to unmarshal want JSON: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(got), &gotObj); err != nil {
|
||||
t.Fatalf("failed to unmarshal got JSON: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(wantObj, gotObj) {
|
||||
t.Fatalf("json not equal\nwant: %s\ngot: %s", want, got)
|
||||
}
|
||||
}
|
||||
@@ -115,6 +115,7 @@ type RelayInfo struct {
|
||||
SendResponseCount int
|
||||
FinalPreConsumedQuota int // 最终预消耗的配额
|
||||
IsClaudeBetaQuery bool // /v1/messages?beta=true
|
||||
IsChannelTest bool // channel test request
|
||||
|
||||
PriceData types.PriceData
|
||||
|
||||
|
||||
@@ -14,10 +14,12 @@ import (
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/ratio_setting"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
@@ -72,6 +74,28 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
adaptor.Init(info)
|
||||
|
||||
if info.RelayMode == relayconstant.RelayModeChatCompletions &&
|
||||
!model_setting.GetGlobalSettings().PassThroughRequestEnabled &&
|
||||
!info.ChannelSetting.PassThroughBodyEnabled &&
|
||||
service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.OriginModelName) {
|
||||
applySystemPromptIfNeeded(c, info, request)
|
||||
usage, newApiErr := chatCompletionsViaResponses(c, info, adaptor, request)
|
||||
if newApiErr != nil {
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
var containAudioTokens = usage.CompletionTokenDetails.AudioTokens > 0 || usage.PromptTokensDetails.AudioTokens > 0
|
||||
var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName)
|
||||
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage, "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
@@ -181,22 +205,25 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
return newApiErr
|
||||
}
|
||||
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
var containAudioTokens = usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0
|
||||
var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName)
|
||||
|
||||
if containAudioTokens && containsAudioRatios {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
extraContent += "(可能是请求出错)"
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
@@ -246,8 +273,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
// search-preview 模型不支持 response api
|
||||
@@ -258,8 +285,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||
searchContextSize, dWebSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||
searchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
// claude web search tool 计费
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
@@ -269,8 +296,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
||||
extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
|
||||
}
|
||||
// file search tool 计费
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
@@ -281,8 +308,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String()))
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
@@ -290,7 +317,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
@@ -300,14 +327,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
baseTokens := dPromptTokens
|
||||
// 减去 cached tokens
|
||||
// Anthropic API 的 input_tokens 已经不包含缓存 tokens,不需要减去
|
||||
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens,需要减去
|
||||
var cachedTokensWithRatio decimal.Decimal
|
||||
if !dCacheTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
}
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
}
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
@@ -325,7 +358,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
// 重新计算 base tokens
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
@@ -350,17 +383,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
dOtherRatio := decimal.NewFromFloat(otherRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
|
||||
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
var logContent string
|
||||
//var logContent string
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
@@ -399,15 +440,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
|
||||
@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -193,7 +193,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -292,6 +292,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -110,8 +110,6 @@ func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.Embeddi
|
||||
return embeddingRequest, nil
|
||||
}
|
||||
|
||||
// GetAndValidateResponsesRequest parses the HTTP request body into an OpenAIResponsesRequest and ensures the Model field is provided.
|
||||
// It returns the parsed request, or an error if the body cannot be parsed or the Model is empty.
|
||||
func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
|
||||
request := &dto.OpenAIResponsesRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
@@ -121,6 +119,9 @@ func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest
|
||||
if request.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if request.Input == nil {
|
||||
return nil, errors.New("input is required")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -323,4 +324,4 @@ func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatch
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,12 +124,18 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
quality = "hd"
|
||||
}
|
||||
|
||||
var logContent string
|
||||
var logContent []string
|
||||
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
|
||||
logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
|
||||
}
|
||||
if len(quality) > 0 {
|
||||
logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
|
||||
}
|
||||
if request.N > 0 {
|
||||
logContent = append(logContent, fmt.Sprintf("生成数量 %d", request.N))
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -150,6 +150,14 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 auto 分组:从 context 获取实际选中的分组
|
||||
// 当使用 auto 分组时,Distribute 中间件会将实际选中的分组存储在 ContextKeyAutoGroup 中
|
||||
if autoGroup, exists := common.GetContextKey(c, constant.ContextKeyAutoGroup); exists {
|
||||
if groupStr, ok := autoGroup.(string); ok && groupStr != "" {
|
||||
info.UsingGroup = groupStr
|
||||
}
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup)
|
||||
var ratio float64
|
||||
|
||||
@@ -95,6 +95,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -93,6 +93,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.POST("/2fa/enable", controller.Enable2FA)
|
||||
selfRoute.POST("/2fa/disable", controller.Disable2FA)
|
||||
selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes)
|
||||
|
||||
// Check-in routes
|
||||
selfRoute.GET("/checkin", controller.GetCheckinStatus)
|
||||
selfRoute.POST("/checkin", middleware.TurnstileCheck(), controller.DoCheckin)
|
||||
}
|
||||
|
||||
adminRoute := userRoute.Group("/")
|
||||
@@ -152,6 +156,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/fix", controller.FixChannelsAbilities)
|
||||
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
|
||||
channelRoute.POST("/fetch_models", controller.FetchModels)
|
||||
channelRoute.POST("/ollama/pull", controller.OllamaPullModel)
|
||||
channelRoute.POST("/ollama/pull/stream", controller.OllamaPullModelStream)
|
||||
channelRoute.DELETE("/ollama/delete", controller.OllamaDeleteModel)
|
||||
channelRoute.GET("/ollama/version/:id", controller.OllamaVersion)
|
||||
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
|
||||
channelRoute.GET("/tag/models", controller.GetTagModels)
|
||||
channelRoute.POST("/copy/:id", controller.CopyChannel)
|
||||
@@ -256,5 +264,31 @@ func SetApiRouter(router *gin.Engine) {
|
||||
modelsRoute.PUT("/", controller.UpdateModelMeta)
|
||||
modelsRoute.DELETE("/:id", controller.DeleteModelMeta)
|
||||
}
|
||||
|
||||
// Deployments (model deployment management)
|
||||
deploymentsRoute := apiRouter.Group("/deployments")
|
||||
deploymentsRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
deploymentsRoute.GET("/settings", controller.GetModelDeploymentSettings)
|
||||
deploymentsRoute.POST("/settings/test-connection", controller.TestIoNetConnection)
|
||||
deploymentsRoute.GET("/", controller.GetAllDeployments)
|
||||
deploymentsRoute.GET("/search", controller.SearchDeployments)
|
||||
deploymentsRoute.POST("/test-connection", controller.TestIoNetConnection)
|
||||
deploymentsRoute.GET("/hardware-types", controller.GetHardwareTypes)
|
||||
deploymentsRoute.GET("/locations", controller.GetLocations)
|
||||
deploymentsRoute.GET("/available-replicas", controller.GetAvailableReplicas)
|
||||
deploymentsRoute.POST("/price-estimation", controller.GetPriceEstimation)
|
||||
deploymentsRoute.GET("/check-name", controller.CheckClusterNameAvailability)
|
||||
deploymentsRoute.POST("/", controller.CreateDeployment)
|
||||
|
||||
deploymentsRoute.GET("/:id", controller.GetDeployment)
|
||||
deploymentsRoute.GET("/:id/logs", controller.GetDeploymentLogs)
|
||||
deploymentsRoute.GET("/:id/containers", controller.ListDeploymentContainers)
|
||||
deploymentsRoute.GET("/:id/containers/:container_id", controller.GetContainerDetails)
|
||||
deploymentsRoute.PUT("/:id", controller.UpdateDeployment)
|
||||
deploymentsRoute.PUT("/:id/name", controller.UpdateDeploymentName)
|
||||
deploymentsRoute.POST("/:id/extend", controller.ExtendDeployment)
|
||||
deploymentsRoute.DELETE("/:id", controller.DeleteDeployment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,9 +57,12 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
|
||||
if types.IsSkipRetryError(err) {
|
||||
return false
|
||||
}
|
||||
if err.StatusCode == http.StatusUnauthorized {
|
||||
if operation_setting.ShouldDisableByStatusCode(err.StatusCode) {
|
||||
return true
|
||||
}
|
||||
//if err.StatusCode == http.StatusUnauthorized {
|
||||
// return true
|
||||
//}
|
||||
if err.StatusCode == http.StatusForbidden {
|
||||
switch channelType {
|
||||
case constant.ChannelTypeGemini:
|
||||
|
||||
@@ -389,25 +389,29 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
|
||||
idx := blockIndex
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: toolCall.ID,
|
||||
Type: "tool_use",
|
||||
Name: toolCall.Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
if toolCall.Function.Name != "" {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_start",
|
||||
ContentBlock: &dto.ClaudeMediaMessage{
|
||||
Id: toolCall.ID,
|
||||
Type: "tool_use",
|
||||
Name: toolCall.Function.Name,
|
||||
Input: map[string]interface{}{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &toolCall.Function.Arguments,
|
||||
},
|
||||
})
|
||||
if len(toolCall.Function.Arguments) > 0 {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Index: &idx,
|
||||
Type: "content_block_delta",
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
Type: "input_json_delta",
|
||||
PartialJson: &toolCall.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
info.ClaudeConvertInfo.Index = blockIndex
|
||||
}
|
||||
@@ -670,20 +674,21 @@ func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycomm
|
||||
var tools []dto.ToolCallRequest
|
||||
for _, tool := range geminiRequest.GetTools() {
|
||||
if tool.FunctionDeclarations != nil {
|
||||
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
|
||||
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
|
||||
if ok {
|
||||
for _, function := range functionDeclarations {
|
||||
openAITool := dto.ToolCallRequest{
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: function.Name,
|
||||
Description: function.Description,
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
tools = append(tools, openAITool)
|
||||
functionDeclarations, err := common.Any2Type[[]dto.FunctionRequest](tool.FunctionDeclarations)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to parse gemini function declarations: %v (type=%T)", err, tool.FunctionDeclarations))
|
||||
continue
|
||||
}
|
||||
for _, function := range functionDeclarations {
|
||||
openAITool := dto.ToolCallRequest{
|
||||
Type: "function",
|
||||
Function: dto.FunctionRequest{
|
||||
Name: function.Name,
|
||||
Description: function.Description,
|
||||
Parameters: function.Parameters,
|
||||
},
|
||||
}
|
||||
tools = append(tools, openAITool)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,33 +81,26 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
|
||||
return claudeErr
|
||||
}
|
||||
|
||||
// RelayErrorHandler converts an HTTP error response into a structured types.NewAPIError.
|
||||
// It returns a NewAPIError initialized with the response status code and one of:
|
||||
// - an Err describing an absent or unreadable body,
|
||||
// - an Err containing the unmarshaled error message (or status + raw body when showBodyWhenFail is true), or
|
||||
// - an embedded OpenAI-style error when the response body contains a compatible error object.
|
||||
// The returned NewAPIError's status code reflects resp.StatusCode.
|
||||
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
|
||||
if resp.Body == nil {
|
||||
newApiErr.Err = errors.New("response body is nil")
|
||||
return
|
||||
}
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
CloseResponseBodyGracefully(resp)
|
||||
newApiErr.Err = fmt.Errorf("read response body failed: %w", err)
|
||||
return
|
||||
}
|
||||
CloseResponseBodyGracefully(resp)
|
||||
var errResponse dto.GeneralErrorResponse
|
||||
buildErrWithBody := func(message string) error {
|
||||
if message == "" {
|
||||
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||
}
|
||||
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, string(responseBody))
|
||||
}
|
||||
|
||||
err = common.Unmarshal(responseBody, &errResponse)
|
||||
if err != nil {
|
||||
if showBodyWhenFail {
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||
newApiErr.Err = buildErrWithBody("")
|
||||
} else {
|
||||
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
||||
@@ -120,10 +113,16 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
|
||||
oaiError := errResponse.TryToOpenAIError()
|
||||
if oaiError != nil {
|
||||
newApiErr = types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
if showBodyWhenFail {
|
||||
newApiErr.Err = buildErrWithBody(newApiErr.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
if showBodyWhenFail {
|
||||
newApiErr.Err = buildErrWithBody(newApiErr.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -169,4 +168,4 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
|
||||
}
|
||||
|
||||
return taskError
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,4 +57,5 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||
if err != nil {
|
||||
logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||
}
|
||||
c.Writer.Flush()
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ func InitHttpClient() {
|
||||
MaxIdleConns: common.RelayMaxIdleConns,
|
||||
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
|
||||
ForceAttemptHTTP2: true,
|
||||
Proxy: http.ProxyFromEnvironment, // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
|
||||
}
|
||||
|
||||
if common.RelayTimeout == 0 {
|
||||
@@ -81,6 +82,9 @@ func ResetProxyClientCache() {
|
||||
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
|
||||
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
|
||||
if proxyURL == "" {
|
||||
if client := GetHttpClient(); client != nil {
|
||||
return client, nil
|
||||
}
|
||||
return http.DefaultClient, nil
|
||||
}
|
||||
|
||||
|
||||
18
service/openai_chat_responses_compat.go
Normal file
18
service/openai_chat_responses_compat.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/service/openaicompat"
|
||||
)
|
||||
|
||||
func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*dto.OpenAIResponsesRequest, error) {
|
||||
return openaicompat.ChatCompletionsRequestToResponsesRequest(req)
|
||||
}
|
||||
|
||||
func ResponsesResponseToChatCompletionsResponse(resp *dto.OpenAIResponsesResponse, id string) (*dto.OpenAITextResponse, *dto.Usage, error) {
|
||||
return openaicompat.ResponsesResponseToChatCompletionsResponse(resp, id)
|
||||
}
|
||||
|
||||
func ExtractOutputTextFromResponses(resp *dto.OpenAIResponsesResponse) string {
|
||||
return openaicompat.ExtractOutputTextFromResponses(resp)
|
||||
}
|
||||
14
service/openai_chat_responses_mode.go
Normal file
14
service/openai_chat_responses_mode.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/service/openaicompat"
|
||||
"github.com/QuantumNous/new-api/setting/model_setting"
|
||||
)
|
||||
|
||||
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
|
||||
return openaicompat.ShouldChatCompletionsUseResponsesPolicy(policy, channelID, model)
|
||||
}
|
||||
|
||||
func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
|
||||
return openaicompat.ShouldChatCompletionsUseResponsesGlobal(channelID, model)
|
||||
}
|
||||
262
service/openaicompat/chat_to_responses.go
Normal file
262
service/openaicompat/chat_to_responses.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package openaicompat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
)
|
||||
|
||||
func normalizeChatImageURLToString(v any) any {
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
return vv
|
||||
case map[string]any:
|
||||
if url := common.Interface2String(vv["url"]); url != "" {
|
||||
return url
|
||||
}
|
||||
return v
|
||||
case dto.MessageImageUrl:
|
||||
if vv.Url != "" {
|
||||
return vv.Url
|
||||
}
|
||||
return v
|
||||
case *dto.MessageImageUrl:
|
||||
if vv != nil && vv.Url != "" {
|
||||
return vv.Url
|
||||
}
|
||||
return v
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func ChatCompletionsRequestToResponsesRequest(req *dto.GeneralOpenAIRequest) (*dto.OpenAIResponsesRequest, error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if req.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if req.N > 1 {
|
||||
return nil, fmt.Errorf("n>1 is not supported in responses compatibility mode")
|
||||
}
|
||||
|
||||
var instructionsParts []string
|
||||
inputItems := make([]map[string]any, 0, len(req.Messages))
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
role := strings.TrimSpace(msg.Role)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Prefer mapping system/developer messages into `instructions`.
|
||||
if role == "system" || role == "developer" {
|
||||
if msg.Content == nil {
|
||||
continue
|
||||
}
|
||||
if msg.IsStringContent() {
|
||||
if s := strings.TrimSpace(msg.StringContent()); s != "" {
|
||||
instructionsParts = append(instructionsParts, s)
|
||||
}
|
||||
continue
|
||||
}
|
||||
parts := msg.ParseContent()
|
||||
var sb strings.Builder
|
||||
for _, part := range parts {
|
||||
if part.Type == dto.ContentTypeText && strings.TrimSpace(part.Text) != "" {
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
if s := strings.TrimSpace(sb.String()); s != "" {
|
||||
instructionsParts = append(instructionsParts, s)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
item := map[string]any{
|
||||
"role": role,
|
||||
}
|
||||
|
||||
if msg.Content == nil {
|
||||
item["content"] = ""
|
||||
inputItems = append(inputItems, item)
|
||||
continue
|
||||
}
|
||||
|
||||
if msg.IsStringContent() {
|
||||
item["content"] = msg.StringContent()
|
||||
inputItems = append(inputItems, item)
|
||||
continue
|
||||
}
|
||||
|
||||
parts := msg.ParseContent()
|
||||
contentParts := make([]map[string]any, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
switch part.Type {
|
||||
case dto.ContentTypeText:
|
||||
contentParts = append(contentParts, map[string]any{
|
||||
"type": "input_text",
|
||||
"text": part.Text,
|
||||
})
|
||||
case dto.ContentTypeImageURL:
|
||||
contentParts = append(contentParts, map[string]any{
|
||||
"type": "input_image",
|
||||
"image_url": normalizeChatImageURLToString(part.ImageUrl),
|
||||
})
|
||||
case dto.ContentTypeInputAudio:
|
||||
contentParts = append(contentParts, map[string]any{
|
||||
"type": "input_audio",
|
||||
"input_audio": part.InputAudio,
|
||||
})
|
||||
case dto.ContentTypeFile:
|
||||
contentParts = append(contentParts, map[string]any{
|
||||
"type": "input_file",
|
||||
"file": part.File,
|
||||
})
|
||||
case dto.ContentTypeVideoUrl:
|
||||
contentParts = append(contentParts, map[string]any{
|
||||
"type": "input_video",
|
||||
"video_url": part.VideoUrl,
|
||||
})
|
||||
default:
|
||||
// Best-effort: keep unknown parts as-is to avoid silently dropping context.
|
||||
contentParts = append(contentParts, map[string]any{
|
||||
"type": part.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
item["content"] = contentParts
|
||||
inputItems = append(inputItems, item)
|
||||
}
|
||||
|
||||
inputRaw, err := common.Marshal(inputItems)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var instructionsRaw json.RawMessage
|
||||
if len(instructionsParts) > 0 {
|
||||
instructions := strings.Join(instructionsParts, "\n\n")
|
||||
instructionsRaw, _ = common.Marshal(instructions)
|
||||
}
|
||||
|
||||
var toolsRaw json.RawMessage
|
||||
if req.Tools != nil {
|
||||
tools := make([]map[string]any, 0, len(req.Tools))
|
||||
for _, tool := range req.Tools {
|
||||
switch tool.Type {
|
||||
case "function":
|
||||
tools = append(tools, map[string]any{
|
||||
"type": "function",
|
||||
"name": tool.Function.Name,
|
||||
"description": tool.Function.Description,
|
||||
"parameters": tool.Function.Parameters,
|
||||
})
|
||||
default:
|
||||
// Best-effort: keep original tool shape for unknown types.
|
||||
var m map[string]any
|
||||
if b, err := common.Marshal(tool); err == nil {
|
||||
_ = common.Unmarshal(b, &m)
|
||||
}
|
||||
if len(m) == 0 {
|
||||
m = map[string]any{"type": tool.Type}
|
||||
}
|
||||
tools = append(tools, m)
|
||||
}
|
||||
}
|
||||
toolsRaw, _ = common.Marshal(tools)
|
||||
}
|
||||
|
||||
var toolChoiceRaw json.RawMessage
|
||||
if req.ToolChoice != nil {
|
||||
switch v := req.ToolChoice.(type) {
|
||||
case string:
|
||||
toolChoiceRaw, _ = common.Marshal(v)
|
||||
default:
|
||||
var m map[string]any
|
||||
if b, err := common.Marshal(v); err == nil {
|
||||
_ = common.Unmarshal(b, &m)
|
||||
}
|
||||
if m == nil {
|
||||
toolChoiceRaw, _ = common.Marshal(v)
|
||||
} else if t, _ := m["type"].(string); t == "function" {
|
||||
// Chat: {"type":"function","function":{"name":"..."}}
|
||||
// Responses: {"type":"function","name":"..."}
|
||||
if name, ok := m["name"].(string); ok && name != "" {
|
||||
toolChoiceRaw, _ = common.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
})
|
||||
} else if fn, ok := m["function"].(map[string]any); ok {
|
||||
if name, ok := fn["name"].(string); ok && name != "" {
|
||||
toolChoiceRaw, _ = common.Marshal(map[string]any{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
})
|
||||
} else {
|
||||
toolChoiceRaw, _ = common.Marshal(v)
|
||||
}
|
||||
} else {
|
||||
toolChoiceRaw, _ = common.Marshal(v)
|
||||
}
|
||||
} else {
|
||||
toolChoiceRaw, _ = common.Marshal(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var parallelToolCallsRaw json.RawMessage
|
||||
if req.ParallelTooCalls != nil {
|
||||
parallelToolCallsRaw, _ = common.Marshal(*req.ParallelTooCalls)
|
||||
}
|
||||
|
||||
var textRaw json.RawMessage
|
||||
if req.ResponseFormat != nil && req.ResponseFormat.Type != "" {
|
||||
textRaw, _ = common.Marshal(map[string]any{
|
||||
"format": req.ResponseFormat,
|
||||
})
|
||||
}
|
||||
|
||||
maxOutputTokens := req.MaxTokens
|
||||
if req.MaxCompletionTokens > maxOutputTokens {
|
||||
maxOutputTokens = req.MaxCompletionTokens
|
||||
}
|
||||
|
||||
var topP *float64
|
||||
if req.TopP != 0 {
|
||||
topP = common.GetPointer(req.TopP)
|
||||
}
|
||||
|
||||
out := &dto.OpenAIResponsesRequest{
|
||||
Model: req.Model,
|
||||
Input: inputRaw,
|
||||
Instructions: instructionsRaw,
|
||||
MaxOutputTokens: maxOutputTokens,
|
||||
Stream: req.Stream,
|
||||
Temperature: req.Temperature,
|
||||
Text: textRaw,
|
||||
ToolChoice: toolChoiceRaw,
|
||||
Tools: toolsRaw,
|
||||
TopP: topP,
|
||||
User: req.User,
|
||||
ParallelToolCalls: parallelToolCallsRaw,
|
||||
Store: req.Store,
|
||||
Metadata: req.Metadata,
|
||||
}
|
||||
|
||||
if req.ReasoningEffort != "" && req.ReasoningEffort != "none" {
|
||||
out.Reasoning = &dto.Reasoning{
|
||||
Effort: req.ReasoningEffort,
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
18
service/openaicompat/policy.go
Normal file
18
service/openaicompat/policy.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package openaicompat
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/model_setting"
|
||||
|
||||
func ShouldChatCompletionsUseResponsesPolicy(policy model_setting.ChatCompletionsToResponsesPolicy, channelID int, model string) bool {
|
||||
if !policy.IsChannelEnabled(channelID) {
|
||||
return false
|
||||
}
|
||||
return matchAnyRegex(policy.ModelPatterns, model)
|
||||
}
|
||||
|
||||
func ShouldChatCompletionsUseResponsesGlobal(channelID int, model string) bool {
|
||||
return ShouldChatCompletionsUseResponsesPolicy(
|
||||
model_setting.GetGlobalSettings().ChatCompletionsToResponsesPolicy,
|
||||
channelID,
|
||||
model,
|
||||
)
|
||||
}
|
||||
33
service/openaicompat/regex.go
Normal file
33
service/openaicompat/regex.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package openaicompat
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var compiledRegexCache sync.Map // map[string]*regexp.Regexp
|
||||
|
||||
func matchAnyRegex(patterns []string, s string) bool {
|
||||
if len(patterns) == 0 || s == "" {
|
||||
return false
|
||||
}
|
||||
for _, pattern := range patterns {
|
||||
if pattern == "" {
|
||||
continue
|
||||
}
|
||||
re, ok := compiledRegexCache.Load(pattern)
|
||||
if !ok {
|
||||
compiled, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
// Treat invalid patterns as non-matching to avoid breaking runtime traffic.
|
||||
continue
|
||||
}
|
||||
re = compiled
|
||||
compiledRegexCache.Store(pattern, re)
|
||||
}
|
||||
if re.(*regexp.Regexp).MatchString(s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
133
service/openaicompat/responses_to_chat.go
Normal file
133
service/openaicompat/responses_to_chat.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package openaicompat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
)
|
||||
|
||||
func ResponsesResponseToChatCompletionsResponse(resp *dto.OpenAIResponsesResponse, id string) (*dto.OpenAITextResponse, *dto.Usage, error) {
|
||||
if resp == nil {
|
||||
return nil, nil, errors.New("response is nil")
|
||||
}
|
||||
|
||||
text := ExtractOutputTextFromResponses(resp)
|
||||
|
||||
usage := &dto.Usage{}
|
||||
if resp.Usage != nil {
|
||||
if resp.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = resp.Usage.InputTokens
|
||||
usage.InputTokens = resp.Usage.InputTokens
|
||||
}
|
||||
if resp.Usage.OutputTokens != 0 {
|
||||
usage.CompletionTokens = resp.Usage.OutputTokens
|
||||
usage.OutputTokens = resp.Usage.OutputTokens
|
||||
}
|
||||
if resp.Usage.TotalTokens != 0 {
|
||||
usage.TotalTokens = resp.Usage.TotalTokens
|
||||
} else {
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
}
|
||||
if resp.Usage.InputTokensDetails != nil {
|
||||
usage.PromptTokensDetails.CachedTokens = resp.Usage.InputTokensDetails.CachedTokens
|
||||
usage.PromptTokensDetails.ImageTokens = resp.Usage.InputTokensDetails.ImageTokens
|
||||
usage.PromptTokensDetails.AudioTokens = resp.Usage.InputTokensDetails.AudioTokens
|
||||
}
|
||||
if resp.Usage.CompletionTokenDetails.ReasoningTokens != 0 {
|
||||
usage.CompletionTokenDetails.ReasoningTokens = resp.Usage.CompletionTokenDetails.ReasoningTokens
|
||||
}
|
||||
}
|
||||
|
||||
created := resp.CreatedAt
|
||||
|
||||
var toolCalls []dto.ToolCallResponse
|
||||
if text == "" && len(resp.Output) > 0 {
|
||||
for _, out := range resp.Output {
|
||||
if out.Type != "function_call" {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(out.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
callId := strings.TrimSpace(out.CallId)
|
||||
if callId == "" {
|
||||
callId = strings.TrimSpace(out.ID)
|
||||
}
|
||||
toolCalls = append(toolCalls, dto.ToolCallResponse{
|
||||
ID: callId,
|
||||
Type: "function",
|
||||
Function: dto.FunctionResponse{
|
||||
Name: name,
|
||||
Arguments: out.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
msg := dto.Message{
|
||||
Role: "assistant",
|
||||
Content: text,
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
msg.SetToolCalls(toolCalls)
|
||||
msg.Content = ""
|
||||
}
|
||||
|
||||
out := &dto.OpenAITextResponse{
|
||||
Id: id,
|
||||
Object: "chat.completion",
|
||||
Created: created,
|
||||
Model: resp.Model,
|
||||
Choices: []dto.OpenAITextResponseChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: msg,
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: *usage,
|
||||
}
|
||||
|
||||
return out, usage, nil
|
||||
}
|
||||
|
||||
func ExtractOutputTextFromResponses(resp *dto.OpenAIResponsesResponse) string {
|
||||
if resp == nil || len(resp.Output) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Prefer assistant message outputs.
|
||||
for _, out := range resp.Output {
|
||||
if out.Type != "message" {
|
||||
continue
|
||||
}
|
||||
if out.Role != "" && out.Role != "assistant" {
|
||||
continue
|
||||
}
|
||||
for _, c := range out.Content {
|
||||
if c.Type == "output_text" && c.Text != "" {
|
||||
sb.WriteString(c.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
return sb.String()
|
||||
}
|
||||
for _, out := range resp.Output {
|
||||
for _, c := range out.Content {
|
||||
if c.Text != "" {
|
||||
sb.WriteString(c.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
|
||||
token, err := model.GetTokenByKey(strings.TrimPrefix(relayInfo.TokenKey, "sk-"), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,14 +1,36 @@
|
||||
package model_setting
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/setting/config"
|
||||
)
|
||||
|
||||
type ChatCompletionsToResponsesPolicy struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
AllChannels bool `json:"all_channels"`
|
||||
ChannelIDs []int `json:"channel_ids,omitempty"`
|
||||
ModelPatterns []string `json:"model_patterns,omitempty"`
|
||||
}
|
||||
|
||||
func (p ChatCompletionsToResponsesPolicy) IsChannelEnabled(channelID int) bool {
|
||||
if !p.Enabled {
|
||||
return false
|
||||
}
|
||||
if p.AllChannels {
|
||||
return true
|
||||
}
|
||||
if channelID == 0 || len(p.ChannelIDs) == 0 {
|
||||
return false
|
||||
}
|
||||
return slices.Contains(p.ChannelIDs, channelID)
|
||||
}
|
||||
|
||||
type GlobalSettings struct {
|
||||
PassThroughRequestEnabled bool `json:"pass_through_request_enabled"`
|
||||
ThinkingModelBlacklist []string `json:"thinking_model_blacklist"`
|
||||
PassThroughRequestEnabled bool `json:"pass_through_request_enabled"`
|
||||
ThinkingModelBlacklist []string `json:"thinking_model_blacklist"`
|
||||
ChatCompletionsToResponsesPolicy ChatCompletionsToResponsesPolicy `json:"chat_completions_to_responses_policy"`
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
@@ -18,6 +40,10 @@ var defaultOpenaiSettings = GlobalSettings{
|
||||
"moonshotai/kimi-k2-thinking",
|
||||
"kimi-k2-thinking",
|
||||
},
|
||||
ChatCompletionsToResponsesPolicy: ChatCompletionsToResponsesPolicy{
|
||||
Enabled: false,
|
||||
AllChannels: true,
|
||||
},
|
||||
}
|
||||
|
||||
// 全局实例
|
||||
|
||||
37
setting/operation_setting/checkin_setting.go
Normal file
37
setting/operation_setting/checkin_setting.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package operation_setting
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/config"
|
||||
|
||||
// CheckinSetting 签到功能配置
|
||||
type CheckinSetting struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用签到功能
|
||||
MinQuota int `json:"min_quota"` // 签到最小额度奖励
|
||||
MaxQuota int `json:"max_quota"` // 签到最大额度奖励
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var checkinSetting = CheckinSetting{
|
||||
Enabled: false, // 默认关闭
|
||||
MinQuota: 1000, // 默认最小额度 1000 (约 0.002 USD)
|
||||
MaxQuota: 10000, // 默认最大额度 10000 (约 0.02 USD)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("checkin_setting", &checkinSetting)
|
||||
}
|
||||
|
||||
// GetCheckinSetting 获取签到配置
|
||||
func GetCheckinSetting() *CheckinSetting {
|
||||
return &checkinSetting
|
||||
}
|
||||
|
||||
// IsCheckinEnabled 是否启用签到功能
|
||||
func IsCheckinEnabled() bool {
|
||||
return checkinSetting.Enabled
|
||||
}
|
||||
|
||||
// GetCheckinQuotaRange 获取签到额度范围
|
||||
func GetCheckinQuotaRange() (min, max int) {
|
||||
return checkinSetting.MinQuota, checkinSetting.MaxQuota
|
||||
}
|
||||
147
setting/operation_setting/status_code_ranges.go
Normal file
147
setting/operation_setting/status_code_ranges.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StatusCodeRange struct {
|
||||
Start int
|
||||
End int
|
||||
}
|
||||
|
||||
var AutomaticDisableStatusCodeRanges = []StatusCodeRange{{Start: 401, End: 401}}
|
||||
|
||||
func AutomaticDisableStatusCodesToString() string {
|
||||
if len(AutomaticDisableStatusCodeRanges) == 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, 0, len(AutomaticDisableStatusCodeRanges))
|
||||
for _, r := range AutomaticDisableStatusCodeRanges {
|
||||
if r.Start == r.End {
|
||||
parts = append(parts, strconv.Itoa(r.Start))
|
||||
continue
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("%d-%d", r.Start, r.End))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
func AutomaticDisableStatusCodesFromString(s string) error {
|
||||
ranges, err := ParseHTTPStatusCodeRanges(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
AutomaticDisableStatusCodeRanges = ranges
|
||||
return nil
|
||||
}
|
||||
|
||||
func ShouldDisableByStatusCode(code int) bool {
|
||||
if code < 100 || code > 599 {
|
||||
return false
|
||||
}
|
||||
for _, r := range AutomaticDisableStatusCodeRanges {
|
||||
if code < r.Start {
|
||||
return false
|
||||
}
|
||||
if code <= r.End {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ParseHTTPStatusCodeRanges(input string) ([]StatusCodeRange, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
input = strings.NewReplacer(",", ",").Replace(input)
|
||||
segments := strings.Split(input, ",")
|
||||
|
||||
var ranges []StatusCodeRange
|
||||
var invalid []string
|
||||
|
||||
for _, seg := range segments {
|
||||
seg = strings.TrimSpace(seg)
|
||||
if seg == "" {
|
||||
continue
|
||||
}
|
||||
r, err := parseHTTPStatusCodeToken(seg)
|
||||
if err != nil {
|
||||
invalid = append(invalid, seg)
|
||||
continue
|
||||
}
|
||||
ranges = append(ranges, r)
|
||||
}
|
||||
|
||||
if len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("invalid http status code rules: %s", strings.Join(invalid, ", "))
|
||||
}
|
||||
if len(ranges) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Slice(ranges, func(i, j int) bool {
|
||||
if ranges[i].Start == ranges[j].Start {
|
||||
return ranges[i].End < ranges[j].End
|
||||
}
|
||||
return ranges[i].Start < ranges[j].Start
|
||||
})
|
||||
|
||||
merged := []StatusCodeRange{ranges[0]}
|
||||
for _, r := range ranges[1:] {
|
||||
last := &merged[len(merged)-1]
|
||||
if r.Start <= last.End+1 {
|
||||
if r.End > last.End {
|
||||
last.End = r.End
|
||||
}
|
||||
continue
|
||||
}
|
||||
merged = append(merged, r)
|
||||
}
|
||||
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
func parseHTTPStatusCodeToken(token string) (StatusCodeRange, error) {
|
||||
token = strings.TrimSpace(token)
|
||||
token = strings.ReplaceAll(token, " ", "")
|
||||
if token == "" {
|
||||
return StatusCodeRange{}, fmt.Errorf("empty token")
|
||||
}
|
||||
|
||||
if strings.Contains(token, "-") {
|
||||
parts := strings.Split(token, "-")
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid range token: %s", token)
|
||||
}
|
||||
start, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid range start: %s", token)
|
||||
}
|
||||
end, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid range end: %s", token)
|
||||
}
|
||||
if start > end {
|
||||
return StatusCodeRange{}, fmt.Errorf("range start > end: %s", token)
|
||||
}
|
||||
if start < 100 || end > 599 {
|
||||
return StatusCodeRange{}, fmt.Errorf("range out of bounds: %s", token)
|
||||
}
|
||||
return StatusCodeRange{Start: start, End: end}, nil
|
||||
}
|
||||
|
||||
code, err := strconv.Atoi(token)
|
||||
if err != nil {
|
||||
return StatusCodeRange{}, fmt.Errorf("invalid status code: %s", token)
|
||||
}
|
||||
if code < 100 || code > 599 {
|
||||
return StatusCodeRange{}, fmt.Errorf("status code out of bounds: %s", token)
|
||||
}
|
||||
return StatusCodeRange{Start: code, End: code}, nil
|
||||
}
|
||||
52
setting/operation_setting/status_code_ranges_test.go
Normal file
52
setting/operation_setting/status_code_ranges_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_CommaSeparated(t *testing.T) {
|
||||
ranges, err := ParseHTTPStatusCodeRanges("401,403,500-599")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []StatusCodeRange{
|
||||
{Start: 401, End: 401},
|
||||
{Start: 403, End: 403},
|
||||
{Start: 500, End: 599},
|
||||
}, ranges)
|
||||
}
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_MergeAndNormalize(t *testing.T) {
|
||||
ranges, err := ParseHTTPStatusCodeRanges("500-505,504,401,403,402")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []StatusCodeRange{
|
||||
{Start: 401, End: 403},
|
||||
{Start: 500, End: 505},
|
||||
}, ranges)
|
||||
}
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_Invalid(t *testing.T) {
|
||||
_, err := ParseHTTPStatusCodeRanges("99,600,foo,500-400,500-")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseHTTPStatusCodeRanges_NoComma_IsInvalid(t *testing.T) {
|
||||
_, err := ParseHTTPStatusCodeRanges("401 403")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestShouldDisableByStatusCode(t *testing.T) {
|
||||
orig := AutomaticDisableStatusCodeRanges
|
||||
t.Cleanup(func() { AutomaticDisableStatusCodeRanges = orig })
|
||||
|
||||
AutomaticDisableStatusCodeRanges = []StatusCodeRange{
|
||||
{Start: 401, End: 403},
|
||||
{Start: 500, End: 599},
|
||||
}
|
||||
|
||||
require.True(t, ShouldDisableByStatusCode(401))
|
||||
require.True(t, ShouldDisableByStatusCode(403))
|
||||
require.False(t, ShouldDisableByStatusCode(404))
|
||||
require.True(t, ShouldDisableByStatusCode(500))
|
||||
require.False(t, ShouldDisableByStatusCode(200))
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/QuantumNous/new-api/setting/reasoning"
|
||||
)
|
||||
|
||||
// from songquanpeng/one-api
|
||||
@@ -312,6 +311,10 @@ var defaultAudioCompletionRatio = map[string]float64{
|
||||
"gpt-4o-realtime": 2,
|
||||
"gpt-4o-mini-realtime": 2,
|
||||
"gpt-4o-mini-tts": 1,
|
||||
"tts-1": 0,
|
||||
"tts-1-hd": 0,
|
||||
"tts-1-1106": 0,
|
||||
"tts-1-hd-1106": 0,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -657,7 +660,7 @@ func GetAudioRatio(name string) float64 {
|
||||
if ratio, ok := audioRatioMap[name]; ok {
|
||||
return ratio
|
||||
}
|
||||
return 20
|
||||
return 1
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatio(name string) float64 {
|
||||
@@ -668,7 +671,23 @@ func GetAudioCompletionRatio(name string) float64 {
|
||||
|
||||
return ratio
|
||||
}
|
||||
return 2
|
||||
return 1
|
||||
}
|
||||
|
||||
func ContainsAudioRatio(name string) bool {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
_, ok := audioRatioMap[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func ContainsAudioCompletionRatio(name string) bool {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
name = FormatMatchingModelName(name)
|
||||
_, ok := audioCompletionRatioMap[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
@@ -746,16 +765,6 @@ func UpdateAudioRatioByJSONString(jsonStr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetAudioRatioCopy() map[string]float64 {
|
||||
audioRatioMapMutex.RLock()
|
||||
defer audioRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(audioRatioMap))
|
||||
for k, v := range audioRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func AudioCompletionRatio2JSONString() string {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
@@ -778,16 +787,6 @@ func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetAudioCompletionRatioCopy() map[string]float64 {
|
||||
audioCompletionRatioMapMutex.RLock()
|
||||
defer audioCompletionRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(audioCompletionRatioMap))
|
||||
for k, v := range audioCompletionRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func GetModelRatioCopy() map[string]float64 {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
@@ -829,10 +828,6 @@ func FormatMatchingModelName(name string) string {
|
||||
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
|
||||
}
|
||||
|
||||
if base, _, ok := reasoning.TrimEffortSuffix(name); ok {
|
||||
name = base
|
||||
}
|
||||
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"github.com/samber/lo"
|
||||
)
|
||||
|
||||
var EffortSuffixes = []string{"-high", "-medium", "-low"}
|
||||
var EffortSuffixes = []string{"-high", "-medium", "-low", "-minimal"}
|
||||
|
||||
// TrimEffortSuffix -> modelName level(low) exists
|
||||
func TrimEffortSuffix(modelName string) (string, string, bool) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -10,10 +11,11 @@ import (
|
||||
)
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param string `json:"param"`
|
||||
Code any `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param string `json:"param"`
|
||||
Code any `json:"code"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
@@ -92,6 +94,7 @@ type NewAPIError struct {
|
||||
errorType ErrorType
|
||||
errorCode ErrorCode
|
||||
StatusCode int
|
||||
Metadata json.RawMessage
|
||||
}
|
||||
|
||||
// Unwrap enables errors.Is / errors.As to work with NewAPIError by exposing the underlying error.
|
||||
@@ -127,6 +130,20 @@ func (e *NewAPIError) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *NewAPIError) ErrorWithStatusCode() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
msg := e.Error()
|
||||
if e.StatusCode == 0 {
|
||||
return msg
|
||||
}
|
||||
if msg == "" {
|
||||
return fmt.Sprintf("status_code=%d", e.StatusCode)
|
||||
}
|
||||
return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
|
||||
}
|
||||
|
||||
func (e *NewAPIError) MaskSensitiveError() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
@@ -141,6 +158,20 @@ func (e *NewAPIError) MaskSensitiveError() string {
|
||||
return common.MaskSensitiveInfo(errStr)
|
||||
}
|
||||
|
||||
func (e *NewAPIError) MaskSensitiveErrorWithStatusCode() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
msg := e.MaskSensitiveError()
|
||||
if e.StatusCode == 0 {
|
||||
return msg
|
||||
}
|
||||
if msg == "" {
|
||||
return fmt.Sprintf("status_code=%d", e.StatusCode)
|
||||
}
|
||||
return fmt.Sprintf("status_code=%d, %s", e.StatusCode, msg)
|
||||
}
|
||||
|
||||
func (e *NewAPIError) SetMessage(message string) {
|
||||
e.Err = errors.New(message)
|
||||
}
|
||||
@@ -301,6 +332,13 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIError
|
||||
Err: errors.New(openAIError.Message),
|
||||
errorCode: ErrorCode(code),
|
||||
}
|
||||
// OpenRouter
|
||||
if len(openAIError.Metadata) > 0 {
|
||||
openAIError.Message = fmt.Sprintf("%s (%s)", openAIError.Message, openAIError.Metadata)
|
||||
e.Metadata = openAIError.Metadata
|
||||
e.RelayError = openAIError
|
||||
e.Err = errors.New(openAIError.Message)
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(e)
|
||||
}
|
||||
|
||||
@@ -26,12 +26,22 @@ type PriceData struct {
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p *PriceData) AddOtherRatio(key string, ratio float64) {
|
||||
if p.OtherRatios == nil {
|
||||
p.OtherRatios = make(map[string]float64)
|
||||
}
|
||||
if ratio <= 0 {
|
||||
return
|
||||
}
|
||||
p.OtherRatios[key] = ratio
|
||||
}
|
||||
|
||||
type PerCallPriceData struct {
|
||||
ModelPrice float64
|
||||
Quota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
func (p *PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
|
||||
}
|
||||
|
||||
@@ -25,7 +25,9 @@ export default defineConfig({
|
||||
"zh",
|
||||
"en",
|
||||
"fr",
|
||||
"ru"
|
||||
"ru",
|
||||
"ja",
|
||||
"vi"
|
||||
],
|
||||
extract: {
|
||||
input: [
|
||||
|
||||
@@ -42,6 +42,7 @@ import Midjourney from './pages/Midjourney';
|
||||
import Pricing from './pages/Pricing';
|
||||
import Task from './pages/Task';
|
||||
import ModelPage from './pages/Model';
|
||||
import ModelDeploymentPage from './pages/ModelDeployment';
|
||||
import Playground from './pages/Playground';
|
||||
import OAuth2Callback from './components/auth/OAuth2Callback';
|
||||
import PersonalSetting from './components/settings/PersonalSetting';
|
||||
@@ -108,6 +109,14 @@ function App() {
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/console/deployment'
|
||||
element={
|
||||
<AdminRoute>
|
||||
<ModelDeploymentPage />
|
||||
</AdminRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/console/channel'
|
||||
element={
|
||||
|
||||
@@ -59,6 +59,11 @@ import { SiDiscord }from 'react-icons/si';
|
||||
const LoginForm = () => {
|
||||
let navigate = useNavigate();
|
||||
const { t } = useTranslation();
|
||||
const githubButtonTextKeyByState = {
|
||||
idle: '使用 GitHub 继续',
|
||||
redirecting: '正在跳转 GitHub...',
|
||||
timeout: '请求超时,请刷新页面后重新发起 GitHub 登录',
|
||||
};
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
password: '',
|
||||
@@ -90,9 +95,10 @@ const LoginForm = () => {
|
||||
const [agreedToTerms, setAgreedToTerms] = useState(false);
|
||||
const [hasUserAgreement, setHasUserAgreement] = useState(false);
|
||||
const [hasPrivacyPolicy, setHasPrivacyPolicy] = useState(false);
|
||||
const [githubButtonText, setGithubButtonText] = useState('使用 GitHub 继续');
|
||||
const [githubButtonState, setGithubButtonState] = useState('idle');
|
||||
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
|
||||
const githubTimeoutRef = useRef(null);
|
||||
const githubButtonText = t(githubButtonTextKeyByState[githubButtonState]);
|
||||
|
||||
const logo = getLogo();
|
||||
const systemName = getSystemName();
|
||||
@@ -284,13 +290,13 @@ const LoginForm = () => {
|
||||
}
|
||||
setGithubLoading(true);
|
||||
setGithubButtonDisabled(true);
|
||||
setGithubButtonText(t('正在跳转 GitHub...'));
|
||||
setGithubButtonState('redirecting');
|
||||
if (githubTimeoutRef.current) {
|
||||
clearTimeout(githubTimeoutRef.current);
|
||||
}
|
||||
githubTimeoutRef.current = setTimeout(() => {
|
||||
setGithubLoading(false);
|
||||
setGithubButtonText(t('请求超时,请刷新页面后重新发起 GitHub 登录'));
|
||||
setGithubButtonState('timeout');
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
|
||||
@@ -57,6 +57,11 @@ import { SiDiscord } from 'react-icons/si';
|
||||
const RegisterForm = () => {
|
||||
let navigate = useNavigate();
|
||||
const { t } = useTranslation();
|
||||
const githubButtonTextKeyByState = {
|
||||
idle: '使用 GitHub 继续',
|
||||
redirecting: '正在跳转 GitHub...',
|
||||
timeout: '请求超时,请刷新页面后重新发起 GitHub 登录',
|
||||
};
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
password: '',
|
||||
@@ -88,9 +93,10 @@ const RegisterForm = () => {
|
||||
const [agreedToTerms, setAgreedToTerms] = useState(false);
|
||||
const [hasUserAgreement, setHasUserAgreement] = useState(false);
|
||||
const [hasPrivacyPolicy, setHasPrivacyPolicy] = useState(false);
|
||||
const [githubButtonText, setGithubButtonText] = useState('使用 GitHub 继续');
|
||||
const [githubButtonState, setGithubButtonState] = useState('idle');
|
||||
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
|
||||
const githubTimeoutRef = useRef(null);
|
||||
const githubButtonText = t(githubButtonTextKeyByState[githubButtonState]);
|
||||
|
||||
const logo = getLogo();
|
||||
const systemName = getSystemName();
|
||||
@@ -251,13 +257,13 @@ const RegisterForm = () => {
|
||||
}
|
||||
setGithubLoading(true);
|
||||
setGithubButtonDisabled(true);
|
||||
setGithubButtonText(t('正在跳转 GitHub...'));
|
||||
setGithubButtonState('redirecting');
|
||||
if (githubTimeoutRef.current) {
|
||||
clearTimeout(githubTimeoutRef.current);
|
||||
}
|
||||
githubTimeoutRef.current = setTimeout(() => {
|
||||
setGithubLoading(false);
|
||||
setGithubButtonText(t('请求超时,请刷新页面后重新发起 GitHub 登录'));
|
||||
setGithubButtonState('timeout');
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
|
||||
@@ -45,6 +45,7 @@ const routerMap = {
|
||||
pricing: '/pricing',
|
||||
task: '/console/task',
|
||||
models: '/console/models',
|
||||
deployment: '/console/deployment',
|
||||
playground: '/console/playground',
|
||||
personal: '/console/personal',
|
||||
};
|
||||
@@ -157,6 +158,12 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
to: '/console/models',
|
||||
className: isAdmin() ? '' : 'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: t('模型部署'),
|
||||
itemKey: 'deployment',
|
||||
to: '/deployment',
|
||||
className: isAdmin() ? '' : 'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: t('兑换码管理'),
|
||||
itemKey: 'redemption',
|
||||
|
||||
@@ -52,7 +52,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: isMobile ? 40 : width, height }}
|
||||
/>
|
||||
}
|
||||
@@ -71,7 +70,7 @@ const SkeletonWrapper = ({
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar active size='extra-small' className='shadow-sm' />
|
||||
<Skeleton.Avatar size='extra-small' className='shadow-sm' />
|
||||
}
|
||||
/>
|
||||
<div className='ml-1.5 mr-1'>
|
||||
@@ -80,7 +79,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: isMobile ? 15 : width, height: 12 }}
|
||||
/>
|
||||
}
|
||||
@@ -98,7 +96,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Image
|
||||
active
|
||||
className={`absolute inset-0 !rounded-full ${className}`}
|
||||
style={{ width: '100%', height: '100%' }}
|
||||
/>
|
||||
@@ -113,7 +110,7 @@ const SkeletonWrapper = ({
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={<Skeleton.Title active style={{ width, height: 24 }} />}
|
||||
placeholder={<Skeleton.Title style={{ width, height: 24 }} />}
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -125,7 +122,7 @@ const SkeletonWrapper = ({
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={<Skeleton.Title active style={{ width, height }} />}
|
||||
placeholder={<Skeleton.Title style={{ width, height }} />}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
@@ -140,7 +137,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width, height, borderRadius: 9999 }}
|
||||
/>
|
||||
}
|
||||
@@ -164,7 +160,7 @@ const SkeletonWrapper = ({
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar active size='extra-small' shape='square' />
|
||||
<Skeleton.Avatar size='extra-small' shape='square' />
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
@@ -174,7 +170,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: width || 80, height: height || 14 }}
|
||||
/>
|
||||
}
|
||||
@@ -191,10 +186,7 @@ const SkeletonWrapper = ({
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: width || 60, height: height || 12 }}
|
||||
/>
|
||||
<Skeleton.Title style={{ width: width || 60, height: height || 12 }} />
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
@@ -217,7 +209,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar
|
||||
active
|
||||
shape='square'
|
||||
style={{ width: ICON_SIZE, height: ICON_SIZE }}
|
||||
/>
|
||||
@@ -231,7 +222,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: labelWidth, height: TEXT_HEIGHT }}
|
||||
/>
|
||||
}
|
||||
@@ -269,7 +259,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar
|
||||
active
|
||||
shape='square'
|
||||
style={{ width: ICON_SIZE, height: ICON_SIZE }}
|
||||
/>
|
||||
@@ -329,7 +318,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
|
||||
/>
|
||||
}
|
||||
@@ -350,7 +338,6 @@ const SkeletonWrapper = ({
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
|
||||
/>
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user