Compare commits

..

94 Commits

Author SHA1 Message Date
CaIon
ad61c0f89e fix(gin): update request body size check to allow zero limit 2026-01-05 18:55:24 +08:00
Seefs
9addf1b705 Merge pull request #2581 from seefs001/fix/batch-add-key-deduplicate 2026-01-05 18:52:18 +08:00
Seefs
9e61338a6f Merge pull request #2582 from seefs001/fix/tips
fix: add tips for model management and channel testing
2026-01-05 18:47:02 +08:00
Calcium-Ion
d3f33932c0 Merge pull request #2580 from seefs001/fix/aws-proxy-timeout
fix: fix the proxyURL is empty, not using the default HTTP client configuration && the AWS calling side did not apply the relay timeout.
2026-01-05 18:32:25 +08:00
Seefs
a8f7c0614f fix: batch add key backend deduplication 2026-01-05 18:09:02 +08:00
Seefs
5f37a1e97c fix: fix the proxyURL is empty, not using the default HTTP client configuration && the AWS calling side did not apply the relay timeout. 2026-01-05 17:56:24 +08:00
Calcium-Ion
177553af37 Merge pull request #2578 from xyfacai/fix/gemini-mimetype
fix: 修复 gemini 文件类型不支持 image/jpg
2026-01-04 22:19:16 +08:00
Xyfacai
5ed4583c0c fix: 修复 gemini 文件类型不支持 image/jpg 2026-01-04 22:09:03 +08:00
Seefs
1519e97bc6 Merge pull request #2550 from shikaiwei1/patch-2 2026-01-04 18:11:46 +08:00
CaIon
443b05821f feat: add plans directory to .gitignore 2026-01-04 16:20:58 +08:00
Seefs
6b1f8b84c6 Merge pull request #2568 from seefs001/feature/channel_override_trim_prefix 2026-01-03 12:38:32 +08:00
Seefs
22d0b73d21 fix: fix model deployment style issues, lint problems, and i18n gaps. (#2556)
* fix: fix model deployment style issues, lint problems, and i18n gaps.

* fix: adjust the key not to be displayed on the frontend, tested via the backend.

* fix: adjust the sidebar configuration logic to use the default configuration items if they are not defined.
2026-01-03 12:37:50 +08:00
Calcium-Ion
e8aaed440c Merge pull request #2558 from seefs001/fix/gemini-tool-call
fix: gemini request -> openai tool call
2026-01-03 12:37:28 +08:00
Calcium-Ion
cfc72d6817 Merge pull request #2571 from seefs001/feature/check-in-security-check
feat: check-in feature integrates Turnstile security check
2026-01-03 12:36:39 +08:00
Seefs
67ba913b44 feat: add support for Doubao /v1/responses (#2567)
* feat: add support for Doubao /v1/responses
2026-01-03 12:35:35 +08:00
Seefs
1f78c6a0f9 Merge pull request #2570 from feitianbubu/pr/43f64c6508515ffaec308ac9c1cf2afa2de98c3d 2026-01-03 12:16:57 +08:00
Seefs
9cf756fc4b Merge pull request #2447 from a4399518s/main 2026-01-03 12:15:28 +08:00
Seefs
c33ac97c71 feat: check-in feature integrates Turnstile security check 2026-01-03 11:08:26 +08:00
feitianbubu
c682e41338 fix: CrossGroupRetry default false
移除gorm:"default:false",避免每次 AutoMigrate时都执行ALTER TABLE `tokens` MODIFY COLUMN `cross_group_retry` boolean DEFAULT false
且bool默认false不影响原有功能
2026-01-03 10:43:33 +08:00
Seefs
817da8d73c feat: add parameter coverage for the operations: copy, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, and regex_replace 2026-01-03 10:27:16 +08:00
Seefs
43c671b8b3 Merge pull request #2393 from prnake/fix-claude-haiku 2026-01-03 09:36:42 +08:00
Seefs
3510b3e6fc Merge pull request #2532 from feitianbubu/pr/620211e02bd55545f0fa4568f3d55c3b4d7f3305 2026-01-03 09:36:17 +08:00
Seefs
f1ba624df2 Merge pull request #2547 from wwalt1a/feat/support-proxy-env-vars 2026-01-03 09:35:51 +08:00
Seefs
2b8cbbe5ae Merge pull request #2425 from atopos31/main 2026-01-03 09:32:50 +08:00
Seefs
fdd1ac0f6e Merge pull request #2554 from zpc7/bugfix/remove-duplicate-condition 2026-01-03 09:31:22 +08:00
Seefs
e5679109d4 Merge pull request #2566 from RedwindA/fix/checkin-frontend-collapse 2026-01-03 09:26:49 +08:00
RedwindA
be8e644546 fix: remove a duplicate key in i18n 2026-01-03 00:55:08 +08:00
RedwindA
a0328b2f5e fix(checkin): prevent visual flicker when loading check-in component
- Add initialLoaded state to track first data load completion
- Set isCollapsed to null initially, determined after data loads
- Show loading state on button and description text before data arrives
- Remove auto-collapse effect that caused visual flicker
- Add i18n translations for loading states (en/fr/ja/ru/vi/zh)

Fixes issue where component would collapse/expand after data loads,
causing visual flicker when navigating to personal settings page.
2026-01-03 00:43:52 +08:00
Calcium-Ion
3402d09ed0 Merge pull request #2565 from QuantumNous/feat/check-in
feat(checkin): add check-in functionality
2026-01-02 23:17:10 +08:00
CaIon
8abfbe372f feat(checkin): add check-in functionality with status retrieval and user quota rewards 2026-01-02 23:00:33 +08:00
CaIon
a195e88896 fix: 修复 timestamp2string1 跨年显示问题,仅在数据跨年时显示年份 2026-01-01 15:42:15 +08:00
CaIon
d06915c30d feat(ratio): add functions to check for audio ratios and clean up unused code 2025-12-31 21:29:10 +08:00
CaIon
b1bb64ae11 feat(model): add audio ratios for new TTS models and adjust default values 2025-12-31 21:22:33 +08:00
CaIon
b2d8ad7883 feat(init): increase maximum file download size to 64MB 2025-12-31 21:15:37 +08:00
Seefs
ddb40b1a6e fix: gemini request -> openai tool call 2025-12-31 18:09:21 +08:00
PCCCCCCC
8b790446ce remove duplicate condition in TaskLogsColumnDefs 2025-12-31 09:38:23 +08:00
CaIon
b808b96cce fix(TaskLogs): use correct video URL for modal preview 2025-12-31 00:44:12 +08:00
CaIon
23a68137ad feat(adaptor): update resolution handling for wan2.6 model 2025-12-31 00:44:06 +08:00
CaIon
2a5b2add9a refactor(image): remove unnecessary logging in oaiImage2Ali function 2025-12-31 00:23:19 +08:00
Calcium-Ion
11922ef651 Merge pull request #2551 from feitianbubu/pr/829cb06b5d689ecbcc05bb3ef49dbf1aec427c35
feat: flush response writer after copying body
2025-12-30 18:09:33 +08:00
feitianbubu
d474ed4778 feat: flush response writer after copying body 2025-12-30 17:52:57 +08:00
John Chen
ab81d6e444 fix: 修复智普、Moonshot渠道在stream=true时无法拿到cachePrompt的统计数据。
根本原因:
1. 在OaiStreamHandler流式处理函数中,调用applyUsagePostProcessing(info, usage, nil)时传入的responseBody为nil,导致无法从响应体中提取缓存tokens。
2. 两个渠道的cached_tokens位置不同:
  - 智普:标准位置 usage.prompt_tokens_details.cached_tokens
  - Moonshot:非标准位置 choices[].usage.cached_tokens

处理方案:
1. 传递body信息到applyUsagePostProcessing中
2. 拆分智普和Moonshot的解析,并为Moonshot单独写一个解析方法。
2025-12-30 17:38:32 +08:00
Hackerxiao
ae2ca945f3 Merge branch 'QuantumNous:main' into main 2025-12-30 11:44:15 +08:00
wwalt1a
04ea79c429 feat: support HTTP_PROXY environment variable for default HTTP client
- Add Proxy: http.ProxyFromEnvironment to default transport
- Allow users to set global proxy via Docker environment variables
- Per-channel proxy settings still override global proxy
- Fully backward compatible
2025-12-30 03:55:06 +08:00
CaIon
48d358faec feat(adaptor): 新适配百炼多种图片生成模型
- wan2.6系列生图与编辑,适配多图生成计费
- wan2.5系列生图与编辑
- z-image-turbo生图,适配prompt_extend计费
2025-12-29 23:00:17 +08:00
Seefs
8063897998 fix: glm 4.7 finish reason (#2545) 2025-12-29 19:41:15 +08:00
Seefs
923dfbeecb Merge pull request #2544 from seefs001/feature/wan-2.6 2025-12-29 14:53:31 +08:00
Seefs
24d359cf40 feat: Add "wan2.6-i2v" video ratio configuration to Ali adaptor. 2025-12-29 14:13:33 +08:00
Seefs
725d61c5d3 feat: ionet integrate (#2105)
* wip ionet integrate

* wip ionet integrate

* wip ionet integrate

* ollama wip

* wip

* feat: ionet integration & ollama manage

* fix merge conflict

* wip

* fix: test conn cors

* wip

* fix ionet

* fix ionet

* wip

* fix model select

* refactor: Remove `pkg/ionet` test files and update related Go source and web UI model deployment components.

* feat: Enhance model deployment UI with styling improvements, updated text, and a new description component.

* Revert "feat: Enhance model deployment UI with styling improvements, updated text, and a new description component."

This reverts commit 8b75cb5bf0d1a534b339df8c033be9a6c7df7964.
2025-12-28 15:55:35 +08:00
Seefs
1a69a93d20 Merge pull request #2536 from RedwindA/feat/oaiDevRole2Gemini 2025-12-28 15:52:45 +08:00
RedwindA
1de78f8749 feat: map OpenAI developer role to Gemini system instructions 2025-12-27 02:52:33 +08:00
feitianbubu
37a1882798 fix: kling correct fail reason 2025-12-26 16:35:46 +08:00
papersnake
edbd5346e4 fix: dup ratio 2025-12-26 16:25:58 +08:00
papersnake
2c2dfea60f Merge branch 'QuantumNous:main' into fix-claude-haiku 2025-12-26 16:23:34 +08:00
skynono
9aeef6abec feat: support first bind update password (#2520) 2025-12-26 13:59:56 +08:00
Seefs
58db72d459 fix: Fix Openrouter test errors and optimize error messages (#2433)
* fix: Refine openrouter error

* fix: Refine openrouter error

* fix: openrouter test max_output_token

* fix: optimize messages

* fix: maxToken unified to 16

* fix: codex系列模型使用 responses接口

* fix: codex系列模型使用 responses接口

* fix: 状态码非200打印错误信息

* fix: 日志里没有报错的响应体
2025-12-26 13:58:44 +08:00
Calcium-Ion
654bb10b45 Merge pull request #2460 from seefs001/feature/gemini-flash-minial
fix(gemini): handle minimal reasoning effort budget
2025-12-26 13:57:56 +08:00
Seefs
f51b5bb0c8 Merge pull request #2455 from comeback01/french-translation 2025-12-26 13:56:30 +08:00
Calcium-Ion
a4cd84f276 Merge pull request #2450 from seefs001/fix/gemini-system-prompt
fix: 支持传入system_instruction和systemInstruction两种风格系统提示词参数名
2025-12-26 13:54:21 +08:00
Calcium-Ion
c722ddd58b Merge pull request #2512 from seefs001/fix/warning-pass-through-body
fix: add warning for pass through body
2025-12-26 13:52:51 +08:00
Calcium-Ion
88e394a976 Merge pull request #2513 from seefs001/fix/token-auth-bearer
fix: 支持小写bearer和Bearer后带多个空格 && 修复 WSS预扣费错误提取key的问题
2025-12-26 13:51:32 +08:00
Seefs
31a3487139 Merge pull request #2528 from QuantumNous/fix/model-sync-overwrite-empty-missing 2025-12-26 13:49:55 +08:00
Seefs
a07406d97e Merge pull request #2530 from RedwindA/fix/i18n-with-http 2025-12-26 13:49:30 +08:00
RedwindA
f68858121c fix(i18n): disable namespace separator to fix URL display in translations
i18next uses ':' as namespace separator by default, causing URLs like
'https://api.openai.com' to be incorrectly parsed as namespace 'https'
with key '//api.openai.com', resulting in truncated display.

Setting nsSeparator to false fixes this issue since the project doesn't
use multiple namespaces.
2025-12-26 00:10:19 +08:00
t0ng7u
83fbaba768 🚀 fix(model-sync): avoid unnecessary upstream fetch while keeping overwrite updates working
- Only short-circuit when there are no missing models AND no overwrite fields requested
- Preserve overwrite behavior even when the missing-model list is empty
- Always return empty arrays (not null) for list fields to keep API responses stable
- Clarify SyncUpstreamModels behavior in comments (create missing models + optional overwrite updates)
2025-12-25 23:01:09 +08:00
Calcium-Ion
d3c854fbed Merge pull request #2154 from feitianbubu/pr/fix-model-sync
fix: ensure overwrite works correctly when no missing models
2025-12-25 22:34:49 +08:00
Calcium-Ion
97b02685b1 Merge pull request #2475 from seefs001/feature/pyro
feat: pyroscope integrate
2025-12-25 17:54:39 +08:00
Seefs
da1b51ac31 Merge branch 'upstream-main' into feature/pyro 2025-12-25 17:08:02 +08:00
CaIon
f17b3810d6 feat(user): simplify user response structure in JSON output 2025-12-25 15:39:58 +08:00
Calcium-Ion
8206084a77 Merge pull request #2524 from seefs001/fix/revert-model-ratio
fix: revert model ratio
2025-12-25 15:38:36 +08:00
Seefs
559da6362a fix: revert model ratio 2025-12-25 15:37:54 +08:00
Calcium-Ion
0b1a562df9 Merge pull request #2477 from 1420970597/fix/anthropic-cache-billing
fix: 修复 Anthropic 渠道缓存计费错误
2025-12-24 16:59:23 +08:00
Seefs
a0c3d37d66 Merge pull request #2493 from shikaiwei1/patch-1 2025-12-24 16:52:24 +08:00
Seefs
347f2326f3 Merge pull request #2511 from JerryKwan/issue2499 2025-12-24 16:51:51 +08:00
Seefs
14c58aea77 fix: 支持小写bearer和Bearer后带多个空格 && 修复 WSS预扣费错误提取key的问题 2025-12-24 15:52:56 +08:00
Seefs
09f3957362 fix: add warning for pass through body 2025-12-24 15:35:36 +08:00
Jerry
31a79620ba Resolving event mismatch in OpenAI2Claude
add stricter validation for content_block_start corresponding to
tool call
and fix the crash issue when Claude Code is processing tool call
2025-12-24 14:52:39 +08:00
Calcium-Ion
12555a37d3 Merge pull request #2510 from feitianbubu/pr/0e7050dc89c1b761069f5e528d8ecf786e7008ae
修复claudeResponse流式请求空指针Panic
2025-12-24 14:15:51 +08:00
feitianbubu
3652dfdbd5 fix: check claudeResponse delta StopReason nil point 2025-12-24 11:54:23 +08:00
John Chen
dbaba87c39 为Moonshot添加缓存tokens读取逻辑
为Moonshot添加缓存tokens读取逻辑。其与智普V4的逻辑相同,所以共用逻辑
2025-12-22 17:05:16 +08:00
comeback01
f04ed7584a Merge branch 'main' into french-translation 2025-12-20 11:08:07 +01:00
长安
0a2f12c04e fix: 修复 Anthropic 渠道缓存计费错误
## 问题描述

当使用 Anthropic 渠道通过 `/v1/chat/completions` 端点调用且启用缓存功能时,
计费逻辑错误地减去了缓存 tokens,导致严重的收入损失(94.5%)。

## 根本原因

不同 API 的 `prompt_tokens` 定义不同:

- **Anthropic API**: `input_tokens` 字段已经是纯输入 tokens(不包含缓存)
- **OpenAI API**: `prompt_tokens` 字段包含所有 tokens(包含缓存)
- **OpenRouter API**: `prompt_tokens` 字段包含所有 tokens(包含缓存)

当前 `postConsumeQuota` 函数对所有渠道都减去缓存 tokens,这对 Anthropic
渠道是错误的,因为其 `input_tokens` 已经不包含缓存。

## 修复方案

在 `relay/compatible_handler.go` 的 `postConsumeQuota` 函数中,添加渠道类型判断:

```go
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
    baseTokens = baseTokens.Sub(dCacheTokens)
}
```

只对非 Anthropic 渠道减去缓存 tokens。

## 影响分析

###  不受影响的场景

1. **无缓存调用**(所有渠道)
   - cache_tokens = 0
   - 减去 0 = 不减去
   - 结果:完全一致

2. **OpenAI/OpenRouter 渠道 + 缓存**
   - 继续减去缓存(因为 ChannelType != Anthropic)
   - 结果:完全一致

3. **Anthropic 渠道 + /v1/messages 端点**
   - 使用 PostClaudeConsumeQuota(不修改)
   - 结果:完全不受影响

###  修复的场景

4. **Anthropic 渠道 + /v1/chat/completions + 缓存**
   - 修复前:错误地减去缓存,导致 94.5% 收入损失
   - 修复后:不减去缓存,计费正确

## 验证数据

以实际记录 143509 为例:

| 项目 | 修复前 | 修复后 | 差异 |
|------|--------|--------|------|
| Quota | 10,489 | 191,330 | +180,841 |
| 费用 | ¥0.020978 | ¥0.382660 | +¥0.361682 |
| 收入恢复 | - | - | **+1724.1%** |

## 测试建议

1. 测试 Anthropic 渠道 + 缓存场景
2. 测试 OpenAI 渠道 + 缓存场景(确保不受影响)
3. 测试无缓存场景(确保不受影响)

## 相关 Issue

修复 Anthropic 渠道使用 prompt caching 时的计费错误。
2025-12-20 14:17:12 +08:00
Seefs
531dfb2555 docs: document pyroscope env var 2025-12-19 23:16:56 +08:00
Seefs
5ef7247eac docs: document pyroscope env var 2025-12-19 23:03:04 +08:00
Seefs
1168ddf9f9 fix: systemname 2025-12-19 22:27:35 +08:00
Seefs
da24a165d0 fix(gemini): handle minimal reasoning effort budget
- Add minimal case to clampThinkingBudgetByEffort to avoid defaulting to full thinking budget
2025-12-18 08:10:46 +08:00
comeback01
f88fc26150 Refine French translations for UI conciseness
Updated web/src/i18n/locales/fr.json to improve French translations for the user interface.

Removed verbose prefixes like 'Gestion des...' and 'Paramètres de...' to prevent truncation in sidebars and menus.

Harmonized terms for consistency (e.g., 'Tâches', 'Journaux', 'Dessins').

Renamed 'Place du marché' to 'Marché des modèles'.
2025-12-17 12:10:36 +01:00
Seefs
2a511c6ee4 fix: 支持传入system_instruction和systemInstruction两种风格系统提示词参数名 2025-12-16 13:08:58 +08:00
旃蒙
0217ed2f98 fix(task): 修复渠道配置多个key时无法获取任务的问题 2025-12-15 18:15:35 +08:00
Seefs
fcafadc6bb feat: pyroscope integrate 2025-12-13 13:49:38 +08:00
hackerxiao
8e629a2a11 feat: 支持仅使用x-api-key获取anthropic格式的模型列表 注释增加 2025-12-12 17:27:24 +08:00
hackerxiao
2a16c37aab feat: 支持仅使用x-api-key获取anthropic格式的模型列表 2025-12-12 16:53:10 +08:00
Papersnake
681b37d104 feat: support claude-haiku-4-5-20251001 on vertex 2025-12-08 17:28:36 +08:00
feitianbubu
35538ecb3b fix: ensure overwrite works correctly when no missing models 2025-11-03 17:50:00 +08:00
133 changed files with 18557 additions and 1820 deletions

View File

@@ -6,4 +6,5 @@
Makefile
docs
.eslintcache
.gocache
.gocache
/web/node_modules

View File

@@ -9,6 +9,14 @@
# ENABLE_PPROF=true
# 启用调试模式
# DEBUG=true
# Pyroscope 配置
# PYROSCOPE_URL=http://localhost:4040
# PYROSCOPE_APP_NAME=new-api
# PYROSCOPE_BASIC_AUTH_USER=your-user
# PYROSCOPE_BASIC_AUTH_PASSWORD=your-password
# PYROSCOPE_MUTEX_RATE=5
# PYROSCOPE_BLOCK_RATE=5
# HOSTNAME=your-hostname
# 数据库相关配置
# 数据库连接字符串

4
.gitignore vendored
View File

@@ -19,7 +19,11 @@ tiktoken_cache
.gomodcache/
.cache
web/bun.lock
plans
electron/node_modules
electron/dist
data/
.gomodcache/
.gocache-temp
.gopath

View File

@@ -308,6 +308,13 @@ docker run --name new-api -d --restart always \
| `MAX_REQUEST_BODY_MB` | Max request body size (MB, counted **after decompression**; prevents huge requests/zip bombs from exhausting memory). Exceeding it returns `413` | `32` |
| `AZURE_DEFAULT_API_VERSION` | Azure API version | `2025-04-01-preview` |
| `ERROR_LOG_ENABLED` | Error log switch | `false` |
| `PYROSCOPE_URL` | Pyroscope server address | - |
| `PYROSCOPE_APP_NAME` | Pyroscope application name | `new-api` |
| `PYROSCOPE_BASIC_AUTH_USER` | Pyroscope basic auth user | - |
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Pyroscope basic auth password | - |
| `PYROSCOPE_MUTEX_RATE` | Pyroscope mutex sampling rate | `5` |
| `PYROSCOPE_BLOCK_RATE` | Pyroscope block sampling rate | `5` |
| `HOSTNAME` | Hostname tag for Pyroscope | `new-api` |
📖 **Complete configuration:** [Environment Variables Documentation](https://docs.newapi.pro/en/docs/installation/config-maintenance/environment-variables)

View File

@@ -304,6 +304,13 @@ docker run --name new-api -d --restart always \
| `MAX_REQUEST_BODY_MB` | Taille maximale du corps de requête (Mo, comptée **après décompression** ; évite les requêtes énormes/zip bombs qui saturent la mémoire). Dépassement ⇒ `413` | `32` |
| `AZURE_DEFAULT_API_VERSION` | Version de l'API Azure | `2025-04-01-preview` |
| `ERROR_LOG_ENABLED` | Interrupteur du journal d'erreurs | `false` |
| `PYROSCOPE_URL` | Adresse du serveur Pyroscope | - |
| `PYROSCOPE_APP_NAME` | Nom de l'application Pyroscope | `new-api` |
| `PYROSCOPE_BASIC_AUTH_USER` | Utilisateur Basic Auth Pyroscope | - |
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Mot de passe Basic Auth Pyroscope | - |
| `PYROSCOPE_MUTEX_RATE` | Taux d'échantillonnage mutex Pyroscope | `5` |
| `PYROSCOPE_BLOCK_RATE` | Taux d'échantillonnage block Pyroscope | `5` |
| `HOSTNAME` | Nom d'hôte tagué pour Pyroscope | `new-api` |
📖 **Configuration complète:** [Documentation des variables d'environnement](https://docs.newapi.pro/en/docs/installation/config-maintenance/environment-variables)

View File

@@ -313,6 +313,13 @@ docker run --name new-api -d --restart always \
| `MAX_REQUEST_BODY_MB` | リクエストボディ最大サイズMB、**解凍後**に計測。巨大リクエスト/zip bomb によるメモリ枯渇を防止)。超過時は `413` | `32` |
| `AZURE_DEFAULT_API_VERSION` | Azure APIバージョン | `2025-04-01-preview` |
| `ERROR_LOG_ENABLED` | エラーログスイッチ | `false` |
| `PYROSCOPE_URL` | Pyroscopeサーバーのアドレス | - |
| `PYROSCOPE_APP_NAME` | Pyroscopeアプリ名 | `new-api` |
| `PYROSCOPE_BASIC_AUTH_USER` | Pyroscope Basic Authユーザー | - |
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Pyroscope Basic Authパスワード | - |
| `PYROSCOPE_MUTEX_RATE` | Pyroscope mutexサンプリング率 | `5` |
| `PYROSCOPE_BLOCK_RATE` | Pyroscope blockサンプリング率 | `5` |
| `HOSTNAME` | Pyroscope用のホスト名タグ | `new-api` |
📖 **完全な設定:** [環境変数ドキュメント](https://docs.newapi.pro/ja/docs/installation/config-maintenance/environment-variables)

View File

@@ -309,6 +309,13 @@ docker run --name new-api -d --restart always \
| `MAX_REQUEST_BODY_MB` | 请求体最大大小MB**解压后**计;防止超大请求/zip bomb 导致内存暴涨),超过将返回 `413` | `32` |
| `AZURE_DEFAULT_API_VERSION` | Azure API 版本 | `2025-04-01-preview` |
| `ERROR_LOG_ENABLED` | 错误日志开关 | `false` |
| `PYROSCOPE_URL` | Pyroscope 服务地址 | - |
| `PYROSCOPE_APP_NAME` | Pyroscope 应用名 | `new-api` |
| `PYROSCOPE_BASIC_AUTH_USER` | Pyroscope Basic Auth 用户名 | - |
| `PYROSCOPE_BASIC_AUTH_PASSWORD` | Pyroscope Basic Auth 密码 | - |
| `PYROSCOPE_MUTEX_RATE` | Pyroscope mutex 采样率 | `5` |
| `PYROSCOPE_BLOCK_RATE` | Pyroscope block 采样率 | `5` |
| `HOSTNAME` | Pyroscope 标签里的主机名 | `new-api` |
📖 **完整配置:** [环境变量文档](https://docs.newapi.pro/zh/docs/installation/config-maintenance/environment-variables)

View File

@@ -40,7 +40,7 @@ func GetRequestBody(c *gin.Context) ([]byte, error) {
}
}
maxMB := constant.MaxRequestBodyMB
if maxMB < 0 {
if maxMB <= 0 {
// no limit
body, err := io.ReadAll(c.Request.Body)
_ = c.Request.Body.Close()

View File

@@ -115,7 +115,7 @@ func InitEnv() {
func initConstantEnv() {
constant.StreamingTimeout = GetEnvOrDefault("STREAMING_TIMEOUT", 300)
constant.DifyDebug = GetEnvOrDefaultBool("DIFY_DEBUG", true)
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 64)
constant.StreamScannerMaxBufferMB = GetEnvOrDefault("STREAM_SCANNER_MAX_BUFFER_MB", 64)
// MaxRequestBodyMB 请求体最大大小(解压后),用于防止超大请求/zip bomb导致内存暴涨
constant.MaxRequestBodyMB = GetEnvOrDefault("MAX_REQUEST_BODY_MB", 64)

56
common/pyro.go Normal file
View File

@@ -0,0 +1,56 @@
package common
import (
"runtime"
"github.com/grafana/pyroscope-go"
)
func StartPyroScope() error {
pyroscopeUrl := GetEnvOrDefaultString("PYROSCOPE_URL", "")
if pyroscopeUrl == "" {
return nil
}
pyroscopeAppName := GetEnvOrDefaultString("PYROSCOPE_APP_NAME", "new-api")
pyroscopeBasicAuthUser := GetEnvOrDefaultString("PYROSCOPE_BASIC_AUTH_USER", "")
pyroscopeBasicAuthPassword := GetEnvOrDefaultString("PYROSCOPE_BASIC_AUTH_PASSWORD", "")
pyroscopeHostname := GetEnvOrDefaultString("HOSTNAME", "new-api")
mutexRate := GetEnvOrDefault("PYROSCOPE_MUTEX_RATE", 5)
blockRate := GetEnvOrDefault("PYROSCOPE_BLOCK_RATE", 5)
runtime.SetMutexProfileFraction(mutexRate)
runtime.SetBlockProfileRate(blockRate)
_, err := pyroscope.Start(pyroscope.Config{
ApplicationName: pyroscopeAppName,
ServerAddress: pyroscopeUrl,
BasicAuthUser: pyroscopeBasicAuthUser,
BasicAuthPassword: pyroscopeBasicAuthPassword,
Logger: nil,
Tags: map[string]string{"hostname": pyroscopeHostname},
ProfileTypes: []pyroscope.ProfileType{
pyroscope.ProfileCPU,
pyroscope.ProfileAllocObjects,
pyroscope.ProfileAllocSpace,
pyroscope.ProfileInuseObjects,
pyroscope.ProfileInuseSpace,
pyroscope.ProfileGoroutines,
pyroscope.ProfileMutexCount,
pyroscope.ProfileMutexDuration,
pyroscope.ProfileBlockCount,
pyroscope.ProfileBlockDuration,
},
})
if err != nil {
return err
}
return nil
}

View File

@@ -40,13 +40,6 @@ type testResult struct {
newAPIError *types.NewAPIError
}
// testChannel executes a test request against the given channel using the provided testModel and optional endpointType,
// and returns a testResult containing the test context and any encountered error information.
// It selects or derives a model when testModel is empty, auto-detects the request endpoint (chat, responses, embeddings, images, rerank) when endpointType is not specified,
// converts and relays the request to the upstream adapter, and parses the upstream response to collect usage and pricing information.
// On upstream responses that indicate the chat/completions `messages` parameter is unsupported and endpointType was not specified, it will retry the test using the Responses API.
// The function records consumption logs and returns a testResult with a populated context on success, or with localErr/newAPIError set on failure;
// for channel types that are not supported for testing it returns a localErr explaining that the channel test is not supported.
func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
tik := time.Now()
var unsupportedTestChannelTypes = []int{
@@ -82,8 +75,6 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
originTestModel := testModel
requestPath := "/v1/chat/completions"
// 如果指定了端点类型,使用指定的端点类型
@@ -93,10 +84,6 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
} else {
// 如果没有指定端点类型,使用原有的自动检测逻辑
if common.IsOpenAIResponseOnlyModel(testModel) {
requestPath = "/v1/responses"
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(testModel), "embedding") ||
strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
@@ -110,6 +97,11 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
requestPath = "/v1/images/generations"
}
// responses-only models
if strings.Contains(strings.ToLower(testModel), "codex") {
requestPath = "/v1/responses"
}
}
c.Request = &http.Request{
@@ -189,7 +181,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
request := buildTestRequest(testModel, endpointType)
request := buildTestRequest(testModel, endpointType, channel)
info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
@@ -332,13 +324,16 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
// 自动检测模式下,如果上游不支持 chat.completions 的 messages 参数,尝试切换到 Responses API 再测一次。
if endpointType == "" && requestPath == "/v1/chat/completions" && err != nil {
lowerErr := strings.ToLower(err.Error())
if strings.Contains(lowerErr, "unsupported parameter") && strings.Contains(lowerErr, "messages") {
return testChannel(channel, originTestModel, string(constant.EndpointTypeOpenAIResponse))
}
}
common.SysError(fmt.Sprintf(
"channel test bad response: channel_id=%d name=%s type=%d model=%s endpoint_type=%s status=%d err=%v",
channel.Id,
channel.Name,
channel.Type,
testModel,
endpointType,
httpResp.StatusCode,
err,
))
return testResult{
context: c,
localErr: err,
@@ -409,8 +404,7 @@ func testChannel(channel *model.Channel, testModel string, endpointType string)
}
}
// for embedding models, and otherwise a chat/completion request with model-specific token limit heuristics.
func buildTestRequest(model string, endpointType string) dto.Request {
func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
// 根据端点类型构建不同的测试请求
if endpointType != "" {
switch constant.EndpointType(endpointType) {
@@ -438,16 +432,13 @@ func buildTestRequest(model string, endpointType string) dto.Request {
}
case constant.EndpointTypeOpenAIResponse:
// 返回 OpenAIResponsesRequest
maxOutputTokens := uint(10)
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
MaxOutputTokens: maxOutputTokens,
Stream: true,
Model: model,
Input: json.RawMessage("\"hi\""),
}
case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
// 返回 GeneralOpenAIRequest
maxTokens := uint(10)
maxTokens := uint(16)
if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
maxTokens = 3000
}
@@ -466,16 +457,6 @@ func buildTestRequest(model string, endpointType string) dto.Request {
}
// 自动检测逻辑(保持原有行为)
if common.IsOpenAIResponseOnlyModel(model) {
maxOutputTokens := uint(10)
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
MaxOutputTokens: maxOutputTokens,
Stream: true,
}
}
// 先判断是否为 Embedding 模型
if strings.Contains(strings.ToLower(model), "embedding") ||
strings.HasPrefix(model, "m3e") ||
@@ -487,6 +468,14 @@ func buildTestRequest(model string, endpointType string) dto.Request {
}
}
// Responses-only models (e.g. codex series)
if strings.Contains(strings.ToLower(model), "codex") {
return &dto.OpenAIResponsesRequest{
Model: model,
Input: json.RawMessage("\"hi\""),
}
}
// Chat/Completion 请求 - 返回 GeneralOpenAIRequest
testRequest := &dto.GeneralOpenAIRequest{
Model: model,
@@ -500,7 +489,7 @@ func buildTestRequest(model string, endpointType string) dto.Request {
}
if strings.HasPrefix(model, "o") {
testRequest.MaxCompletionTokens = 10
testRequest.MaxCompletionTokens = 16
} else if strings.Contains(model, "thinking") {
if !strings.Contains(model, "claude") {
testRequest.MaxTokens = 50
@@ -508,7 +497,7 @@ func buildTestRequest(model string, endpointType string) dto.Request {
} else if strings.Contains(model, "gemini") {
testRequest.MaxTokens = 3000
} else {
testRequest.MaxTokens = 10
testRequest.MaxTokens = 16
}
return testRequest
@@ -674,4 +663,4 @@ func AutomaticallyTestChannels() {
}
}
})
}
}

View File

@@ -11,16 +11,18 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel/ollama"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
Metadata map[string]any `json:"metadata,omitempty"`
Permission []struct {
ID string `json:"id"`
Object string `json:"object"`
@@ -207,6 +209,57 @@ func FetchUpstreamModels(c *gin.Context) {
baseURL = channel.GetBaseURL()
}
// 对于 Ollama 渠道,使用特殊处理
if channel.Type == constant.ChannelTypeOllama {
key := strings.Split(channel.Key, "\n")[0]
models, err := ollama.FetchOllamaModels(baseURL, key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
})
return
}
result := OpenAIModelsResponse{
Data: make([]OpenAIModel, 0, len(models)),
}
for _, modelInfo := range models {
metadata := map[string]any{}
if modelInfo.Size > 0 {
metadata["size"] = modelInfo.Size
}
if modelInfo.Digest != "" {
metadata["digest"] = modelInfo.Digest
}
if modelInfo.ModifiedAt != "" {
metadata["modified_at"] = modelInfo.ModifiedAt
}
details := modelInfo.Details
if details.ParentModel != "" || details.Format != "" || details.Family != "" || len(details.Families) > 0 || details.ParameterSize != "" || details.QuantizationLevel != "" {
metadata["details"] = modelInfo.Details
}
if len(metadata) == 0 {
metadata = nil
}
result.Data = append(result.Data, OpenAIModel{
ID: modelInfo.Name,
Object: "model",
Created: 0,
OwnedBy: "ollama",
Metadata: metadata,
})
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": result.Data,
})
return
}
var url string
switch channel.Type {
case constant.ChannelTypeGemini:
@@ -917,9 +970,6 @@ func UpdateChannel(c *gin.Context) {
// 单个JSON密钥
newKeys = []string{channel.Key}
}
// 合并密钥
allKeys := append(existingKeys, newKeys...)
channel.Key = strings.Join(allKeys, "\n")
} else {
// 普通渠道的处理
inputKeys := strings.Split(channel.Key, "\n")
@@ -929,10 +979,31 @@ func UpdateChannel(c *gin.Context) {
newKeys = append(newKeys, key)
}
}
// 合并密钥
allKeys := append(existingKeys, newKeys...)
channel.Key = strings.Join(allKeys, "\n")
}
seen := make(map[string]struct{}, len(existingKeys)+len(newKeys))
for _, key := range existingKeys {
normalized := strings.TrimSpace(key)
if normalized == "" {
continue
}
seen[normalized] = struct{}{}
}
dedupedNewKeys := make([]string, 0, len(newKeys))
for _, key := range newKeys {
normalized := strings.TrimSpace(key)
if normalized == "" {
continue
}
if _, ok := seen[normalized]; ok {
continue
}
seen[normalized] = struct{}{}
dedupedNewKeys = append(dedupedNewKeys, normalized)
}
allKeys := append(existingKeys, dedupedNewKeys...)
channel.Key = strings.Join(allKeys, "\n")
}
case "replace":
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
@@ -975,6 +1046,32 @@ func FetchModels(c *gin.Context) {
baseURL = constant.ChannelBaseURLs[req.Type]
}
// remove line breaks and extra spaces.
key := strings.TrimSpace(req.Key)
key = strings.Split(key, "\n")[0]
if req.Type == constant.ChannelTypeOllama {
models, err := ollama.FetchOllamaModels(baseURL, key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()),
})
return
}
names := make([]string, 0, len(models))
for _, modelInfo := range models {
names = append(names, modelInfo.Name)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": names,
})
return
}
client := &http.Client{}
url := fmt.Sprintf("%s/v1/models", baseURL)
@@ -987,10 +1084,6 @@ func FetchModels(c *gin.Context) {
return
}
// remove line breaks and extra spaces.
key := strings.TrimSpace(req.Key)
// If the key contains a line break, only take the first part.
key = strings.Split(key, "\n")[0]
request.Header.Set("Authorization", "Bearer "+key)
response, err := client.Do(request)
@@ -1640,3 +1733,262 @@ func ManageMultiKeys(c *gin.Context) {
return
}
}
// OllamaPullModel 拉取 Ollama 模型
func OllamaPullModel(c *gin.Context) {
var req struct {
ChannelID int `json:"channel_id"`
ModelName string `json:"model_name"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Invalid request parameters",
})
return
}
if req.ChannelID == 0 || req.ModelName == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Channel ID and model name are required",
})
return
}
// 获取渠道信息
channel, err := model.GetChannelById(req.ChannelID, true)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Channel not found",
})
return
}
// 检查是否是 Ollama 渠道
if channel.Type != constant.ChannelTypeOllama {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "This operation is only supported for Ollama channels",
})
return
}
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
key := strings.Split(channel.Key, "\n")[0]
err = ollama.PullOllamaModel(baseURL, key, req.ModelName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": fmt.Sprintf("Failed to pull model: %s", err.Error()),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
})
}
// OllamaPullModelStream 流式拉取 Ollama 模型
func OllamaPullModelStream(c *gin.Context) {
var req struct {
ChannelID int `json:"channel_id"`
ModelName string `json:"model_name"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Invalid request parameters",
})
return
}
if req.ChannelID == 0 || req.ModelName == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Channel ID and model name are required",
})
return
}
// 获取渠道信息
channel, err := model.GetChannelById(req.ChannelID, true)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Channel not found",
})
return
}
// 检查是否是 Ollama 渠道
if channel.Type != constant.ChannelTypeOllama {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "This operation is only supported for Ollama channels",
})
return
}
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
// 设置 SSE 头部
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
key := strings.Split(channel.Key, "\n")[0]
// 创建进度回调函数
progressCallback := func(progress ollama.OllamaPullResponse) {
data, _ := json.Marshal(progress)
fmt.Fprintf(c.Writer, "data: %s\n\n", string(data))
c.Writer.Flush()
}
// 执行拉取
err = ollama.PullOllamaModelStream(baseURL, key, req.ModelName, progressCallback)
if err != nil {
errorData, _ := json.Marshal(gin.H{
"error": err.Error(),
})
fmt.Fprintf(c.Writer, "data: %s\n\n", string(errorData))
} else {
successData, _ := json.Marshal(gin.H{
"message": fmt.Sprintf("Model %s pulled successfully", req.ModelName),
})
fmt.Fprintf(c.Writer, "data: %s\n\n", string(successData))
}
// 发送结束标志
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
c.Writer.Flush()
}
// OllamaDeleteModel 删除 Ollama 模型
func OllamaDeleteModel(c *gin.Context) {
var req struct {
ChannelID int `json:"channel_id"`
ModelName string `json:"model_name"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Invalid request parameters",
})
return
}
if req.ChannelID == 0 || req.ModelName == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Channel ID and model name are required",
})
return
}
// 获取渠道信息
channel, err := model.GetChannelById(req.ChannelID, true)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Channel not found",
})
return
}
// 检查是否是 Ollama 渠道
if channel.Type != constant.ChannelTypeOllama {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "This operation is only supported for Ollama channels",
})
return
}
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
key := strings.Split(channel.Key, "\n")[0]
err = ollama.DeleteOllamaModel(baseURL, key, req.ModelName)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": fmt.Sprintf("Failed to delete model: %s", err.Error()),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("Model %s deleted successfully", req.ModelName),
})
}
// OllamaVersion 获取 Ollama 服务版本信息
func OllamaVersion(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Invalid channel id",
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Channel not found",
})
return
}
if channel.Type != constant.ChannelTypeOllama {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "This operation is only supported for Ollama channels",
})
return
}
baseURL := constant.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
key := strings.Split(channel.Key, "\n")[0]
version, err := ollama.FetchOllamaVersion(baseURL, key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("获取Ollama版本失败: %s", err.Error()),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"version": version,
},
})
}

72
controller/checkin.go Normal file
View File

@@ -0,0 +1,72 @@
package controller
import (
"fmt"
"net/http"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
)
// GetCheckinStatus 获取用户签到状态和历史记录
func GetCheckinStatus(c *gin.Context) {
setting := operation_setting.GetCheckinSetting()
if !setting.Enabled {
common.ApiErrorMsg(c, "签到功能未启用")
return
}
userId := c.GetInt("id")
// 获取月份参数,默认为当前月份
month := c.DefaultQuery("month", time.Now().Format("2006-01"))
stats, err := model.GetUserCheckinStats(userId, month)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"enabled": setting.Enabled,
"min_quota": setting.MinQuota,
"max_quota": setting.MaxQuota,
"stats": stats,
},
})
}
// DoCheckin 执行用户签到
func DoCheckin(c *gin.Context) {
setting := operation_setting.GetCheckinSetting()
if !setting.Enabled {
common.ApiErrorMsg(c, "签到功能未启用")
return
}
userId := c.GetInt("id")
checkin, err := model.UserCheckin(userId)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("用户签到,获得额度 %s", logger.LogQuota(checkin.QuotaAwarded)))
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "签到成功",
"data": gin.H{
"quota_awarded": checkin.QuotaAwarded,
"checkin_date": checkin.CheckinDate},
})
}

810
controller/deployment.go Normal file
View File

@@ -0,0 +1,810 @@
package controller
import (
"bytes"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/ionet"
"github.com/gin-gonic/gin"
)
func getIoAPIKey(c *gin.Context) (string, bool) {
common.OptionMapRWMutex.RLock()
enabled := common.OptionMap["model_deployment.ionet.enabled"] == "true"
apiKey := common.OptionMap["model_deployment.ionet.api_key"]
common.OptionMapRWMutex.RUnlock()
if !enabled || strings.TrimSpace(apiKey) == "" {
common.ApiErrorMsg(c, "io.net model deployment is not enabled or api key missing")
return "", false
}
return apiKey, true
}
func GetModelDeploymentSettings(c *gin.Context) {
common.OptionMapRWMutex.RLock()
enabled := common.OptionMap["model_deployment.ionet.enabled"] == "true"
hasAPIKey := strings.TrimSpace(common.OptionMap["model_deployment.ionet.api_key"]) != ""
common.OptionMapRWMutex.RUnlock()
common.ApiSuccess(c, gin.H{
"provider": "io.net",
"enabled": enabled,
"configured": hasAPIKey,
"can_connect": enabled && hasAPIKey,
})
}
func getIoClient(c *gin.Context) (*ionet.Client, bool) {
apiKey, ok := getIoAPIKey(c)
if !ok {
return nil, false
}
return ionet.NewClient(apiKey), true
}
func getIoEnterpriseClient(c *gin.Context) (*ionet.Client, bool) {
apiKey, ok := getIoAPIKey(c)
if !ok {
return nil, false
}
return ionet.NewEnterpriseClient(apiKey), true
}
func TestIoNetConnection(c *gin.Context) {
var req struct {
APIKey string `json:"api_key"`
}
rawBody, err := c.GetRawData()
if err != nil {
common.ApiError(c, err)
return
}
if len(bytes.TrimSpace(rawBody)) > 0 {
if err := json.Unmarshal(rawBody, &req); err != nil {
common.ApiErrorMsg(c, "invalid request payload")
return
}
}
apiKey := strings.TrimSpace(req.APIKey)
if apiKey == "" {
common.OptionMapRWMutex.RLock()
storedKey := strings.TrimSpace(common.OptionMap["model_deployment.ionet.api_key"])
common.OptionMapRWMutex.RUnlock()
if storedKey == "" {
common.ApiErrorMsg(c, "api_key is required")
return
}
apiKey = storedKey
}
client := ionet.NewEnterpriseClient(apiKey)
result, err := client.GetMaxGPUsPerContainer()
if err != nil {
if apiErr, ok := err.(*ionet.APIError); ok {
message := strings.TrimSpace(apiErr.Message)
if message == "" {
message = "failed to validate api key"
}
common.ApiErrorMsg(c, message)
return
}
common.ApiError(c, err)
return
}
totalHardware := 0
totalAvailable := 0
if result != nil {
totalHardware = len(result.Hardware)
totalAvailable = result.Total
if totalAvailable == 0 {
for _, hw := range result.Hardware {
totalAvailable += hw.Available
}
}
}
common.ApiSuccess(c, gin.H{
"hardware_count": totalHardware,
"total_available": totalAvailable,
})
}
func requireDeploymentID(c *gin.Context) (string, bool) {
deploymentID := strings.TrimSpace(c.Param("id"))
if deploymentID == "" {
common.ApiErrorMsg(c, "deployment ID is required")
return "", false
}
return deploymentID, true
}
func requireContainerID(c *gin.Context) (string, bool) {
containerID := strings.TrimSpace(c.Param("container_id"))
if containerID == "" {
common.ApiErrorMsg(c, "container ID is required")
return "", false
}
return containerID, true
}
func mapIoNetDeployment(d ionet.Deployment) map[string]interface{} {
var created int64
if d.CreatedAt.IsZero() {
created = time.Now().Unix()
} else {
created = d.CreatedAt.Unix()
}
timeRemainingHours := d.ComputeMinutesRemaining / 60
timeRemainingMins := d.ComputeMinutesRemaining % 60
var timeRemaining string
if timeRemainingHours > 0 {
timeRemaining = fmt.Sprintf("%d hour %d minutes", timeRemainingHours, timeRemainingMins)
} else if timeRemainingMins > 0 {
timeRemaining = fmt.Sprintf("%d minutes", timeRemainingMins)
} else {
timeRemaining = "completed"
}
hardwareInfo := fmt.Sprintf("%s %s x%d", d.BrandName, d.HardwareName, d.HardwareQuantity)
return map[string]interface{}{
"id": d.ID,
"deployment_name": d.Name,
"container_name": d.Name,
"status": strings.ToLower(d.Status),
"type": "Container",
"time_remaining": timeRemaining,
"time_remaining_minutes": d.ComputeMinutesRemaining,
"hardware_info": hardwareInfo,
"hardware_name": d.HardwareName,
"brand_name": d.BrandName,
"hardware_quantity": d.HardwareQuantity,
"completed_percent": d.CompletedPercent,
"compute_minutes_served": d.ComputeMinutesServed,
"compute_minutes_remaining": d.ComputeMinutesRemaining,
"created_at": created,
"updated_at": created,
"model_name": "",
"model_version": "",
"instance_count": d.HardwareQuantity,
"resource_config": map[string]interface{}{
"cpu": "",
"memory": "",
"gpu": strconv.Itoa(d.HardwareQuantity),
},
"description": "",
"provider": "io.net",
}
}
func computeStatusCounts(total int, deployments []ionet.Deployment) map[string]int64 {
counts := map[string]int64{
"all": int64(total),
}
for _, status := range []string{"running", "completed", "failed", "deployment requested", "termination requested", "destroyed"} {
counts[status] = 0
}
for _, d := range deployments {
status := strings.ToLower(strings.TrimSpace(d.Status))
counts[status] = counts[status] + 1
}
return counts
}
func GetAllDeployments(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
status := c.Query("status")
opts := &ionet.ListDeploymentsOptions{
Status: strings.ToLower(strings.TrimSpace(status)),
Page: pageInfo.GetPage(),
PageSize: pageInfo.GetPageSize(),
SortBy: "created_at",
SortOrder: "desc",
}
dl, err := client.ListDeployments(opts)
if err != nil {
common.ApiError(c, err)
return
}
items := make([]map[string]interface{}, 0, len(dl.Deployments))
for _, d := range dl.Deployments {
items = append(items, mapIoNetDeployment(d))
}
data := gin.H{
"page": pageInfo.GetPage(),
"page_size": pageInfo.GetPageSize(),
"total": dl.Total,
"items": items,
"status_counts": computeStatusCounts(dl.Total, dl.Deployments),
}
common.ApiSuccess(c, data)
}
func SearchDeployments(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
status := strings.ToLower(strings.TrimSpace(c.Query("status")))
keyword := strings.TrimSpace(c.Query("keyword"))
dl, err := client.ListDeployments(&ionet.ListDeploymentsOptions{
Status: status,
Page: pageInfo.GetPage(),
PageSize: pageInfo.GetPageSize(),
SortBy: "created_at",
SortOrder: "desc",
})
if err != nil {
common.ApiError(c, err)
return
}
filtered := make([]ionet.Deployment, 0, len(dl.Deployments))
if keyword == "" {
filtered = dl.Deployments
} else {
kw := strings.ToLower(keyword)
for _, d := range dl.Deployments {
if strings.Contains(strings.ToLower(d.Name), kw) {
filtered = append(filtered, d)
}
}
}
items := make([]map[string]interface{}, 0, len(filtered))
for _, d := range filtered {
items = append(items, mapIoNetDeployment(d))
}
total := dl.Total
if keyword != "" {
total = len(filtered)
}
data := gin.H{
"page": pageInfo.GetPage(),
"page_size": pageInfo.GetPageSize(),
"total": total,
"items": items,
}
common.ApiSuccess(c, data)
}
func GetDeployment(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
details, err := client.GetDeployment(deploymentID)
if err != nil {
common.ApiError(c, err)
return
}
data := map[string]interface{}{
"id": details.ID,
"deployment_name": details.ID,
"model_name": "",
"model_version": "",
"status": strings.ToLower(details.Status),
"instance_count": details.TotalContainers,
"hardware_id": details.HardwareID,
"resource_config": map[string]interface{}{
"cpu": "",
"memory": "",
"gpu": strconv.Itoa(details.TotalGPUs),
},
"created_at": details.CreatedAt.Unix(),
"updated_at": details.CreatedAt.Unix(),
"description": "",
"amount_paid": details.AmountPaid,
"completed_percent": details.CompletedPercent,
"gpus_per_container": details.GPUsPerContainer,
"total_gpus": details.TotalGPUs,
"total_containers": details.TotalContainers,
"hardware_name": details.HardwareName,
"brand_name": details.BrandName,
"compute_minutes_served": details.ComputeMinutesServed,
"compute_minutes_remaining": details.ComputeMinutesRemaining,
"locations": details.Locations,
"container_config": details.ContainerConfig,
}
common.ApiSuccess(c, data)
}
func UpdateDeploymentName(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
var req struct {
Name string `json:"name" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiError(c, err)
return
}
updateReq := &ionet.UpdateClusterNameRequest{
Name: strings.TrimSpace(req.Name),
}
if updateReq.Name == "" {
common.ApiErrorMsg(c, "deployment name cannot be empty")
return
}
available, err := client.CheckClusterNameAvailability(updateReq.Name)
if err != nil {
common.ApiError(c, fmt.Errorf("failed to check name availability: %w", err))
return
}
if !available {
common.ApiErrorMsg(c, "deployment name is not available, please choose a different name")
return
}
resp, err := client.UpdateClusterName(deploymentID, updateReq)
if err != nil {
common.ApiError(c, err)
return
}
data := gin.H{
"status": resp.Status,
"message": resp.Message,
"id": deploymentID,
"name": updateReq.Name,
}
common.ApiSuccess(c, data)
}
func UpdateDeployment(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
var req ionet.UpdateDeploymentRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiError(c, err)
return
}
resp, err := client.UpdateDeployment(deploymentID, &req)
if err != nil {
common.ApiError(c, err)
return
}
data := gin.H{
"status": resp.Status,
"deployment_id": resp.DeploymentID,
}
common.ApiSuccess(c, data)
}
func ExtendDeployment(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
var req ionet.ExtendDurationRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiError(c, err)
return
}
details, err := client.ExtendDeployment(deploymentID, &req)
if err != nil {
common.ApiError(c, err)
return
}
data := mapIoNetDeployment(ionet.Deployment{
ID: details.ID,
Status: details.Status,
Name: deploymentID,
CompletedPercent: float64(details.CompletedPercent),
HardwareQuantity: details.TotalGPUs,
BrandName: details.BrandName,
HardwareName: details.HardwareName,
ComputeMinutesServed: details.ComputeMinutesServed,
ComputeMinutesRemaining: details.ComputeMinutesRemaining,
CreatedAt: details.CreatedAt,
})
common.ApiSuccess(c, data)
}
func DeleteDeployment(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
resp, err := client.DeleteDeployment(deploymentID)
if err != nil {
common.ApiError(c, err)
return
}
data := gin.H{
"status": resp.Status,
"deployment_id": resp.DeploymentID,
"message": "Deployment termination requested successfully",
}
common.ApiSuccess(c, data)
}
func CreateDeployment(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
var req ionet.DeploymentRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiError(c, err)
return
}
resp, err := client.DeployContainer(&req)
if err != nil {
common.ApiError(c, err)
return
}
data := gin.H{
"deployment_id": resp.DeploymentID,
"status": resp.Status,
"message": "Deployment created successfully",
}
common.ApiSuccess(c, data)
}
func GetHardwareTypes(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
hardwareTypes, totalAvailable, err := client.ListHardwareTypes()
if err != nil {
common.ApiError(c, err)
return
}
data := gin.H{
"hardware_types": hardwareTypes,
"total": len(hardwareTypes),
"total_available": totalAvailable,
}
common.ApiSuccess(c, data)
}
func GetLocations(c *gin.Context) {
client, ok := getIoClient(c)
if !ok {
return
}
locationsResp, err := client.ListLocations()
if err != nil {
common.ApiError(c, err)
return
}
total := locationsResp.Total
if total == 0 {
total = len(locationsResp.Locations)
}
data := gin.H{
"locations": locationsResp.Locations,
"total": total,
}
common.ApiSuccess(c, data)
}
func GetAvailableReplicas(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
hardwareIDStr := c.Query("hardware_id")
gpuCountStr := c.Query("gpu_count")
if hardwareIDStr == "" {
common.ApiErrorMsg(c, "hardware_id parameter is required")
return
}
hardwareID, err := strconv.Atoi(hardwareIDStr)
if err != nil || hardwareID <= 0 {
common.ApiErrorMsg(c, "invalid hardware_id parameter")
return
}
gpuCount := 1
if gpuCountStr != "" {
if parsed, err := strconv.Atoi(gpuCountStr); err == nil && parsed > 0 {
gpuCount = parsed
}
}
replicas, err := client.GetAvailableReplicas(hardwareID, gpuCount)
if err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, replicas)
}
func GetPriceEstimation(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
var req ionet.PriceEstimationRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiError(c, err)
return
}
priceResp, err := client.GetPriceEstimation(&req)
if err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, priceResp)
}
func CheckClusterNameAvailability(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
clusterName := strings.TrimSpace(c.Query("name"))
if clusterName == "" {
common.ApiErrorMsg(c, "name parameter is required")
return
}
available, err := client.CheckClusterNameAvailability(clusterName)
if err != nil {
common.ApiError(c, err)
return
}
data := gin.H{
"available": available,
"name": clusterName,
}
common.ApiSuccess(c, data)
}
func GetDeploymentLogs(c *gin.Context) {
client, ok := getIoClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
containerID := c.Query("container_id")
if containerID == "" {
common.ApiErrorMsg(c, "container_id parameter is required")
return
}
level := c.Query("level")
stream := c.Query("stream")
cursor := c.Query("cursor")
limitStr := c.Query("limit")
follow := c.Query("follow") == "true"
var limit int = 100
if limitStr != "" {
if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 {
limit = parsedLimit
if limit > 1000 {
limit = 1000
}
}
}
opts := &ionet.GetLogsOptions{
Level: level,
Stream: stream,
Limit: limit,
Cursor: cursor,
Follow: follow,
}
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse(time.RFC3339, startTime); err == nil {
opts.StartTime = &t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse(time.RFC3339, endTime); err == nil {
opts.EndTime = &t
}
}
rawLogs, err := client.GetContainerLogsRaw(deploymentID, containerID, opts)
if err != nil {
common.ApiError(c, err)
return
}
common.ApiSuccess(c, rawLogs)
}
func ListDeploymentContainers(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
containers, err := client.ListContainers(deploymentID)
if err != nil {
common.ApiError(c, err)
return
}
items := make([]map[string]interface{}, 0)
if containers != nil {
items = make([]map[string]interface{}, 0, len(containers.Workers))
for _, ctr := range containers.Workers {
events := make([]map[string]interface{}, 0, len(ctr.ContainerEvents))
for _, event := range ctr.ContainerEvents {
events = append(events, map[string]interface{}{
"time": event.Time.Unix(),
"message": event.Message,
})
}
items = append(items, map[string]interface{}{
"container_id": ctr.ContainerID,
"device_id": ctr.DeviceID,
"status": strings.ToLower(strings.TrimSpace(ctr.Status)),
"hardware": ctr.Hardware,
"brand_name": ctr.BrandName,
"created_at": ctr.CreatedAt.Unix(),
"uptime_percent": ctr.UptimePercent,
"gpus_per_container": ctr.GPUsPerContainer,
"public_url": ctr.PublicURL,
"events": events,
})
}
}
response := gin.H{
"total": 0,
"containers": items,
}
if containers != nil {
response["total"] = containers.Total
}
common.ApiSuccess(c, response)
}
func GetContainerDetails(c *gin.Context) {
client, ok := getIoEnterpriseClient(c)
if !ok {
return
}
deploymentID, ok := requireDeploymentID(c)
if !ok {
return
}
containerID, ok := requireContainerID(c)
if !ok {
return
}
details, err := client.GetContainerDetails(deploymentID, containerID)
if err != nil {
common.ApiError(c, err)
return
}
if details == nil {
common.ApiErrorMsg(c, "container details not found")
return
}
events := make([]map[string]interface{}, 0, len(details.ContainerEvents))
for _, event := range details.ContainerEvents {
events = append(events, map[string]interface{}{
"time": event.Time.Unix(),
"message": event.Message,
})
}
data := gin.H{
"deployment_id": deploymentID,
"container_id": details.ContainerID,
"device_id": details.DeviceID,
"status": strings.ToLower(strings.TrimSpace(details.Status)),
"hardware": details.Hardware,
"brand_name": details.BrandName,
"created_at": details.CreatedAt.Unix(),
"uptime_percent": details.UptimePercent,
"gpus_per_container": details.GPUsPerContainer,
"public_url": details.PublicURL,
"events": events,
}
common.ApiSuccess(c, data)
}

View File

@@ -114,6 +114,7 @@ func GetStatus(c *gin.Context) {
"setup": constant.Setup,
"user_agreement_enabled": legalSetting.UserAgreement != "",
"privacy_policy_enabled": legalSetting.PrivacyPolicy != "",
"checkin_enabled": operation_setting.GetCheckinSetting().Enabled,
}
// 根据启用状态注入可选内容

View File

@@ -249,7 +249,9 @@ func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, v
return 0
}
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
// SyncUpstreamModels 同步上游模型与供应商
// - 默认仅创建「未配置模型」
// - 可通过 overwrite 选择性覆盖更新本地已有模型的字段前提sync_official <> 0
func SyncUpstreamModels(c *gin.Context) {
var req syncRequest
// 允许空体
@@ -260,12 +262,26 @@ func SyncUpstreamModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
if len(missing) == 0 {
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
"created_models": 0,
"created_vendors": 0,
"skipped_models": []string{},
}})
// 若既无缺失模型需要创建,也未指定覆盖更新字段,则无需请求上游数据,直接返回
if len(missing) == 0 && len(req.Overwrite) == 0 {
modelsURL, vendorsURL := getUpstreamURLs(req.Locale)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"created_models": 0,
"created_vendors": 0,
"updated_models": 0,
"skipped_models": []string{},
"created_list": []string{},
"updated_list": []string{},
"source": gin.H{
"locale": req.Locale,
"models_url": modelsURL,
"vendors_url": vendorsURL,
},
},
})
return
}
@@ -315,9 +331,9 @@ func SyncUpstreamModels(c *gin.Context) {
createdModels := 0
createdVendors := 0
updatedModels := 0
var skipped []string
var createdList []string
var updatedList []string
skipped := make([]string, 0)
createdList := make([]string, 0)
updatedList := make([]string, 0)
// 本地缓存vendorName -> id
vendorIDCache := make(map[string]int)

View File

@@ -20,7 +20,11 @@ func GetOptions(c *gin.Context) {
var options []*model.Option
common.OptionMapRWMutex.Lock()
for k, v := range common.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
if strings.HasSuffix(k, "Token") ||
strings.HasSuffix(k, "Secret") ||
strings.HasSuffix(k, "Key") ||
strings.HasSuffix(k, "secret") ||
strings.HasSuffix(k, "api_key") {
continue
}
options = append(options, &model.Option{

View File

@@ -74,7 +74,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
return fmt.Errorf("task %s not found", taskId)
}
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
key := channel.Key
privateData := task.PrivateData
if privateData.Key != "" {
key = privateData.Key
}
resp, err := adaptor.FetchTask(baseURL, key, map[string]any{
"task_id": taskId,
"action": task.Action,
}, proxy)

View File

@@ -1,6 +1,7 @@
package controller
import (
"fmt"
"net/http"
"strconv"
"strings"
@@ -149,6 +150,24 @@ func AddToken(c *gin.Context) {
})
return
}
// 非无限额度时,检查额度值是否超出有效范围
if !token.UnlimitedQuota {
if token.RemainQuota < 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "额度值不能为负数",
})
return
}
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
if token.RemainQuota > maxQuotaValue {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
})
return
}
}
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -216,6 +235,23 @@ func UpdateToken(c *gin.Context) {
})
return
}
if !token.UnlimitedQuota {
if token.RemainQuota < 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "额度值不能为负数",
})
return
}
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
if token.RemainQuota > maxQuotaValue {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
})
return
}
}
cleanToken, err := model.GetTokenByIds(token.Id, userId)
if err != nil {
common.ApiError(c, err)
@@ -261,7 +297,6 @@ func UpdateToken(c *gin.Context) {
"message": "",
"data": cleanToken,
})
return
}
type TokenBatch struct {

View File

@@ -110,18 +110,17 @@ func setupLogin(user *model.User, c *gin.Context) {
})
return
}
cleanUser := model.User{
Id: user.Id,
Username: user.Username,
DisplayName: user.DisplayName,
Role: user.Role,
Status: user.Status,
Group: user.Group,
}
c.JSON(http.StatusOK, gin.H{
"message": "",
"success": true,
"data": cleanUser,
"data": map[string]any{
"id": user.Id,
"username": user.Username,
"display_name": user.DisplayName,
"role": user.Role,
"status": user.Status,
"group": user.Group,
},
})
}
@@ -764,7 +763,10 @@ func checkUpdatePassword(originalPassword string, newPassword string, userId int
if err != nil {
return
}
if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) {
// 密码不为空,需要验证原密码
// 支持第一次账号绑定时原密码为空的情况
if !common.ValidatePasswordAndHash(originalPassword, currentUser.Password) && currentUser.Password != "" {
err = fmt.Errorf("原密码错误")
return
}

7
docs/ionet-client.md Normal file
View File

@@ -0,0 +1,7 @@
Request URL
https://api.io.solutions/v1/io-cloud/clusters/654fc0a9-0d4a-4db4-9b95-3f56189348a2/update-name
Request Method
PUT
{"status":"succeeded","message":"Cluster name updated successfully"}

View File

@@ -26,6 +26,7 @@ type GeneralErrorResponse struct {
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Metadata json.RawMessage `json:"metadata,omitempty"`
Header struct {
Message string `json:"message"`
} `json:"header"`

View File

@@ -22,6 +22,27 @@ type GeminiChatRequest struct {
CachedContent string `json:"cachedContent,omitempty"`
}
// UnmarshalJSON allows GeminiChatRequest to accept both snake_case and camelCase fields.
func (r *GeminiChatRequest) UnmarshalJSON(data []byte) error {
type Alias GeminiChatRequest
var aux struct {
Alias
SystemInstructionSnake *GeminiChatContent `json:"system_instruction,omitempty"`
}
if err := common.Unmarshal(data, &aux); err != nil {
return err
}
*r = GeminiChatRequest(aux.Alias)
if aux.SystemInstructionSnake != nil {
r.SystemInstructions = aux.SystemInstructionSnake
}
return nil
}
type ToolConfig struct {
FunctionCallingConfig *FunctionCallingConfig `json:"functionCallingConfig,omitempty"`
RetrievalConfig *RetrievalConfig `json:"retrievalConfig,omitempty"`
@@ -105,7 +126,7 @@ func (r *GeminiChatRequest) SetModelName(modelName string) {
func (r *GeminiChatRequest) GetTools() []GeminiChatTool {
var tools []GeminiChatTool
if strings.HasSuffix(string(r.Tools), "[") {
if strings.HasPrefix(string(r.Tools), "[") {
// is array
if err := common.Unmarshal(r.Tools, &tools); err != nil {
logger.LogError(nil, "error_unmarshalling_tools: "+err.Error())

View File

@@ -167,9 +167,9 @@ func (i *ImageRequest) SetModelName(modelName string) {
}
type ImageResponse struct {
Data []ImageData `json:"data"`
Created int64 `json:"created"`
Extra any `json:"extra,omitempty"`
Data []ImageData `json:"data"`
Created int64 `json:"created"`
Metadata json.RawMessage `json:"metadata,omitempty"`
}
type ImageData struct {
Url string `json:"url"`

View File

@@ -23,6 +23,8 @@ type FormatJsonSchema struct {
Strict json.RawMessage `json:"strict,omitempty"`
}
// GeneralOpenAIRequest represents a general request structure for OpenAI-compatible APIs.
// 参数增加规范无引用的参数必须使用json.RawMessage类型并添加omitempty标签
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
@@ -82,8 +84,9 @@ type GeneralOpenAIRequest struct {
Reasoning json.RawMessage `json:"reasoning,omitempty"`
// Ali Qwen Params
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"`
EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
EnableSearch json.RawMessage `json:"enable_search,omitempty"`
// ollama Params
Think json.RawMessage `json:"think,omitempty"`
// baidu v2

8
go.mod
View File

@@ -27,6 +27,7 @@ require (
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/grafana/pyroscope-go v1.2.7
github.com/jfreymuth/oggvorbis v1.0.5
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
@@ -36,6 +37,7 @@ require (
github.com/samber/lo v1.52.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.11.1
github.com/stripe/stripe-go/v81 v81.4.0
github.com/tcolgate/mp3 v0.0.0-20170426193717-e79c5a46d300
github.com/thanhpk/randstr v1.0.6
@@ -62,6 +64,7 @@ require (
github.com/bytedance/sonic/loader v0.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
@@ -77,11 +80,11 @@ require (
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-webauthn/x v0.1.25 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/go-tpm v0.9.5 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
github.com/icza/bitio v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
@@ -91,6 +94,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.8 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
@@ -101,7 +105,9 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect

48
go.sum
View File

@@ -118,9 +118,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/go-tpm v0.9.5 h1:ocUmnDebX54dnW+MQWGQRbdaAcJELsa6PqZhJ48KwVU=
github.com/google/go-tpm v0.9.5/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -132,6 +131,10 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grafana/pyroscope-go v1.2.7 h1:VWBBlqxjyR0Cwk2W6UrE8CdcdD80GOFNutj0Kb1T8ac=
github.com/grafana/pyroscope-go v1.2.7/go.mod h1:o/bpSLiJYYP6HQtvcoVKiE9s5RiNgjYTj1DhiddP2Pc=
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 h1:c1Us8i6eSmkW+Ez05d3co8kasnuOY813tbMN8i/a3Og=
github.com/grafana/pyroscope-go/godeltaprof v0.1.9/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU=
github.com/icza/bitio v1.1.0 h1:ysX4vtldjdi3Ygai5m1cWy4oLkhWTAi+SyO6HC8L9T0=
github.com/icza/bitio v1.1.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6 h1:8UsGZ2rr2ksmEru6lToqnXgA8Mz1DP11X4zSJ159C3k=
@@ -160,12 +163,15 @@ github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwA
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
@@ -214,14 +220,11 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA=
github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw=
github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
@@ -231,6 +234,7 @@ github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+D
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
@@ -288,12 +292,12 @@ golang.org/x/arch v0.21.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o=
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8=
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
@@ -321,6 +325,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
@@ -350,19 +356,29 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
modernc.org/cc/v4 v4.26.5 h1:xM3bX7Mve6G8K8b+T11ReenJOT+BmVqQj0FY5T4+5Y4=
modernc.org/cc/v4 v4.26.5/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.28.1 h1:wPKYn5EC/mYTqBO373jKjvX2n+3+aK7+sICCv4Fjy1A=
modernc.org/ccgo/v4 v4.28.1/go.mod h1:uD+4RnfrVgE6ec9NGguUNdhqzNIeeomeXf6CL0GTE5Q=
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.66.10 h1:yZkb3YeLx4oynyR+iUsXsybsX4Ubx7MQlSYEw4yj59A=
modernc.org/libc v1.66.10/go.mod h1:8vGSEwvoUoltr4dlywvHqjtAqHBaw0j1jI7iFBTAr2I=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@@ -124,6 +124,11 @@ func main() {
common.SysLog("pprof enabled")
}
err = common.StartPyroScope()
if err != nil {
common.SysError(fmt.Sprintf("start pyroscope error : %v", err))
}
// Initialize HTTP server
server := gin.New()
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
@@ -183,6 +188,7 @@ func InjectUmamiAnalytics() {
analyticsInjectBuilder.WriteString(umamiSiteID)
analyticsInjectBuilder.WriteString("\"></script>")
}
analyticsInjectBuilder.WriteString("<!--Umami QuantumNous-->\n")
analyticsInject := analyticsInjectBuilder.String()
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
}
@@ -204,6 +210,7 @@ func InjectGoogleAnalytics() {
analyticsInjectBuilder.WriteString("');")
analyticsInjectBuilder.WriteString("</script>")
}
analyticsInjectBuilder.WriteString("<!--Google Analytics QuantumNous-->\n")
analyticsInject := analyticsInjectBuilder.String()
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
}

View File

@@ -195,8 +195,8 @@ func TokenAuth() func(c *gin.Context) {
}
c.Request.Header.Set("Authorization", "Bearer "+key)
}
// 检查path包含/v1/messages
if strings.Contains(c.Request.URL.Path, "/v1/messages") {
// 检查path包含/v1/messages 或 /v1/models
if strings.Contains(c.Request.URL.Path, "/v1/messages") || strings.Contains(c.Request.URL.Path, "/v1/models") {
anthropicKey := c.Request.Header.Get("x-api-key")
if anthropicKey != "" {
c.Request.Header.Set("Authorization", "Bearer "+anthropicKey)
@@ -218,10 +218,14 @@ func TokenAuth() func(c *gin.Context) {
}
key := c.Request.Header.Get("Authorization")
parts := make([]string, 0)
key = strings.TrimPrefix(key, "Bearer ")
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
key = strings.TrimSpace(key[7:])
}
if key == "" || key == "midjourney-proxy" {
key = c.Request.Header.Get("mj-api-secret")
key = strings.TrimPrefix(key, "Bearer ")
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
key = strings.TrimSpace(key[7:])
}
key = strings.TrimPrefix(key, "sk-")
parts = strings.Split(key, "-")
key = parts[0]

179
model/checkin.go Normal file
View File

@@ -0,0 +1,179 @@
package model
import (
"errors"
"math/rand"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"gorm.io/gorm"
)
// Checkin 签到记录
type Checkin struct {
Id int `json:"id" gorm:"primaryKey;autoIncrement"`
UserId int `json:"user_id" gorm:"not null;uniqueIndex:idx_user_checkin_date"`
CheckinDate string `json:"checkin_date" gorm:"type:varchar(10);not null;uniqueIndex:idx_user_checkin_date"` // 格式: YYYY-MM-DD
QuotaAwarded int `json:"quota_awarded" gorm:"not null"`
CreatedAt int64 `json:"created_at" gorm:"bigint"`
}
// CheckinRecord 用于API返回的签到记录不包含敏感字段
type CheckinRecord struct {
CheckinDate string `json:"checkin_date"`
QuotaAwarded int `json:"quota_awarded"`
}
func (Checkin) TableName() string {
return "checkins"
}
// GetUserCheckinRecords 获取用户在指定日期范围内的签到记录
func GetUserCheckinRecords(userId int, startDate, endDate string) ([]Checkin, error) {
var records []Checkin
err := DB.Where("user_id = ? AND checkin_date >= ? AND checkin_date <= ?",
userId, startDate, endDate).
Order("checkin_date DESC").
Find(&records).Error
return records, err
}
// HasCheckedInToday 检查用户今天是否已签到
func HasCheckedInToday(userId int) (bool, error) {
today := time.Now().Format("2006-01-02")
var count int64
err := DB.Model(&Checkin{}).
Where("user_id = ? AND checkin_date = ?", userId, today).
Count(&count).Error
return count > 0, err
}
// UserCheckin 执行用户签到
// MySQL 和 PostgreSQL 使用事务保证原子性
// SQLite 不支持嵌套事务,使用顺序操作 + 手动回滚
func UserCheckin(userId int) (*Checkin, error) {
setting := operation_setting.GetCheckinSetting()
if !setting.Enabled {
return nil, errors.New("签到功能未启用")
}
// 检查今天是否已签到
hasChecked, err := HasCheckedInToday(userId)
if err != nil {
return nil, err
}
if hasChecked {
return nil, errors.New("今日已签到")
}
// 计算随机额度奖励
quotaAwarded := setting.MinQuota
if setting.MaxQuota > setting.MinQuota {
quotaAwarded = setting.MinQuota + rand.Intn(setting.MaxQuota-setting.MinQuota+1)
}
today := time.Now().Format("2006-01-02")
checkin := &Checkin{
UserId: userId,
CheckinDate: today,
QuotaAwarded: quotaAwarded,
CreatedAt: time.Now().Unix(),
}
// 根据数据库类型选择不同的策略
if common.UsingSQLite {
// SQLite 不支持嵌套事务,使用顺序操作 + 手动回滚
return userCheckinWithoutTransaction(checkin, userId, quotaAwarded)
}
// MySQL 和 PostgreSQL 支持事务,使用事务保证原子性
return userCheckinWithTransaction(checkin, userId, quotaAwarded)
}
// userCheckinWithTransaction 使用事务执行签到(适用于 MySQL 和 PostgreSQL
func userCheckinWithTransaction(checkin *Checkin, userId int, quotaAwarded int) (*Checkin, error) {
err := DB.Transaction(func(tx *gorm.DB) error {
// 步骤1: 创建签到记录
// 数据库有唯一约束 (user_id, checkin_date),可以防止并发重复签到
if err := tx.Create(checkin).Error; err != nil {
return errors.New("签到失败,请稍后重试")
}
// 步骤2: 在事务中增加用户额度
if err := tx.Model(&User{}).Where("id = ?", userId).
Update("quota", gorm.Expr("quota + ?", quotaAwarded)).Error; err != nil {
return errors.New("签到失败:更新额度出错")
}
return nil
})
if err != nil {
return nil, err
}
// 事务成功后,异步更新缓存
go func() {
_ = cacheIncrUserQuota(userId, int64(quotaAwarded))
}()
return checkin, nil
}
// userCheckinWithoutTransaction 不使用事务执行签到(适用于 SQLite
func userCheckinWithoutTransaction(checkin *Checkin, userId int, quotaAwarded int) (*Checkin, error) {
// 步骤1: 创建签到记录
// 数据库有唯一约束 (user_id, checkin_date),可以防止并发重复签到
if err := DB.Create(checkin).Error; err != nil {
return nil, errors.New("签到失败,请稍后重试")
}
// 步骤2: 增加用户额度
// 使用 db=true 强制直接写入数据库,不使用批量更新
if err := IncreaseUserQuota(userId, quotaAwarded, true); err != nil {
// 如果增加额度失败,需要回滚签到记录
DB.Delete(checkin)
return nil, errors.New("签到失败:更新额度出错")
}
return checkin, nil
}
// GetUserCheckinStats 获取用户签到统计信息
func GetUserCheckinStats(userId int, month string) (map[string]interface{}, error) {
// 获取指定月份的所有签到记录
startDate := month + "-01"
endDate := month + "-31"
records, err := GetUserCheckinRecords(userId, startDate, endDate)
if err != nil {
return nil, err
}
// 转换为不包含敏感字段的记录
checkinRecords := make([]CheckinRecord, len(records))
for i, r := range records {
checkinRecords[i] = CheckinRecord{
CheckinDate: r.CheckinDate,
QuotaAwarded: r.QuotaAwarded,
}
}
// 检查今天是否已签到
hasCheckedToday, _ := HasCheckedInToday(userId)
// 获取用户所有时间的签到统计
var totalCheckins int64
var totalQuota int64
DB.Model(&Checkin{}).Where("user_id = ?", userId).Count(&totalCheckins)
DB.Model(&Checkin{}).Where("user_id = ?", userId).Select("COALESCE(SUM(quota_awarded), 0)").Scan(&totalQuota)
return map[string]interface{}{
"total_quota": totalQuota, // 所有时间累计获得的额度
"total_checkins": totalCheckins, // 所有时间累计签到次数
"checkin_count": len(records), // 本月签到次数
"checked_in_today": hasCheckedToday, // 今天是否已签到
"records": checkinRecords, // 本月签到记录详情不含id和user_id
}, nil
}

View File

@@ -267,6 +267,7 @@ func migrateDB() error {
&Setup{},
&TwoFA{},
&TwoFABackupCode{},
&Checkin{},
)
if err != nil {
return err
@@ -300,6 +301,7 @@ func migrateDBFast() error {
{&Setup{}, "Setup"},
{&TwoFA{}, "TwoFA"},
{&TwoFABackupCode{}, "TwoFABackupCode"},
{&Checkin{}, "Checkin"},
}
// 动态计算migration数量确保errChan缓冲区足够大
errChan := make(chan error, len(migrations))

View File

@@ -26,7 +26,7 @@ type Token struct {
AllowIps *string `json:"allow_ips" gorm:"default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Group string `json:"group" gorm:"default:''"`
CrossGroupRetry bool `json:"cross_group_retry" gorm:"default:false"` // 跨分组重试仅auto分组有效
CrossGroupRetry bool `json:"cross_group_retry"` // 跨分组重试仅auto分组有效
DeletedAt gorm.DeletedAt `gorm:"index"`
}

219
pkg/ionet/client.go Normal file
View File

@@ -0,0 +1,219 @@
package ionet
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
)
const (
DefaultEnterpriseBaseURL = "https://api.io.solutions/enterprise/v1/io-cloud/caas"
DefaultBaseURL = "https://api.io.solutions/v1/io-cloud/caas"
DefaultTimeout = 30 * time.Second
)
// DefaultHTTPClient is the default HTTP client implementation
type DefaultHTTPClient struct {
client *http.Client
}
// NewDefaultHTTPClient creates a new default HTTP client
func NewDefaultHTTPClient(timeout time.Duration) *DefaultHTTPClient {
return &DefaultHTTPClient{
client: &http.Client{
Timeout: timeout,
},
}
}
// Do executes an HTTP request
func (c *DefaultHTTPClient) Do(req *HTTPRequest) (*HTTPResponse, error) {
httpReq, err := http.NewRequest(req.Method, req.URL, bytes.NewReader(req.Body))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
// Set headers
for key, value := range req.Headers {
httpReq.Header.Set(key, value)
}
resp, err := c.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("HTTP request failed: %w", err)
}
defer resp.Body.Close()
// Read response body
var body bytes.Buffer
_, err = body.ReadFrom(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
// Convert headers
headers := make(map[string]string)
for key, values := range resp.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
return &HTTPResponse{
StatusCode: resp.StatusCode,
Headers: headers,
Body: body.Bytes(),
}, nil
}
// NewEnterpriseClient creates a new IO.NET API client targeting the enterprise API base URL.
func NewEnterpriseClient(apiKey string) *Client {
return NewClientWithConfig(apiKey, DefaultEnterpriseBaseURL, nil)
}
// NewClient creates a new IO.NET API client targeting the public API base URL.
func NewClient(apiKey string) *Client {
return NewClientWithConfig(apiKey, DefaultBaseURL, nil)
}
// NewClientWithConfig creates a new IO.NET API client with custom configuration
func NewClientWithConfig(apiKey, baseURL string, httpClient HTTPClient) *Client {
if baseURL == "" {
baseURL = DefaultBaseURL
}
if httpClient == nil {
httpClient = NewDefaultHTTPClient(DefaultTimeout)
}
return &Client{
BaseURL: baseURL,
APIKey: apiKey,
HTTPClient: httpClient,
}
}
// makeRequest performs an HTTP request and handles common response processing
func (c *Client) makeRequest(method, endpoint string, body interface{}) (*HTTPResponse, error) {
var reqBody []byte
var err error
if body != nil {
reqBody, err = json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
}
headers := map[string]string{
"X-API-KEY": c.APIKey,
"Content-Type": "application/json",
}
req := &HTTPRequest{
Method: method,
URL: c.BaseURL + endpoint,
Headers: headers,
Body: reqBody,
}
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
// Handle API errors
if resp.StatusCode >= 400 {
var apiErr APIError
if len(resp.Body) > 0 {
// Try to parse the actual error format: {"detail": "message"}
var errorResp struct {
Detail string `json:"detail"`
}
if err := json.Unmarshal(resp.Body, &errorResp); err == nil && errorResp.Detail != "" {
apiErr = APIError{
Code: resp.StatusCode,
Message: errorResp.Detail,
}
} else {
// Fallback: use raw body as details
apiErr = APIError{
Code: resp.StatusCode,
Message: fmt.Sprintf("API request failed with status %d", resp.StatusCode),
Details: string(resp.Body),
}
}
} else {
apiErr = APIError{
Code: resp.StatusCode,
Message: fmt.Sprintf("API request failed with status %d", resp.StatusCode),
}
}
return nil, &apiErr
}
return resp, nil
}
// buildQueryParams builds query parameters for GET requests
func buildQueryParams(params map[string]interface{}) string {
if len(params) == 0 {
return ""
}
values := url.Values{}
for key, value := range params {
if value == nil {
continue
}
switch v := value.(type) {
case string:
if v != "" {
values.Add(key, v)
}
case int:
if v != 0 {
values.Add(key, strconv.Itoa(v))
}
case int64:
if v != 0 {
values.Add(key, strconv.FormatInt(v, 10))
}
case float64:
if v != 0 {
values.Add(key, strconv.FormatFloat(v, 'f', -1, 64))
}
case bool:
values.Add(key, strconv.FormatBool(v))
case time.Time:
if !v.IsZero() {
values.Add(key, v.Format(time.RFC3339))
}
case *time.Time:
if v != nil && !v.IsZero() {
values.Add(key, v.Format(time.RFC3339))
}
case []int:
if len(v) > 0 {
if encoded, err := json.Marshal(v); err == nil {
values.Add(key, string(encoded))
}
}
case []string:
if len(v) > 0 {
if encoded, err := json.Marshal(v); err == nil {
values.Add(key, string(encoded))
}
}
default:
values.Add(key, fmt.Sprint(v))
}
}
if len(values) > 0 {
return "?" + values.Encode()
}
return ""
}

302
pkg/ionet/container.go Normal file
View File

@@ -0,0 +1,302 @@
package ionet
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/samber/lo"
)
// ListContainers retrieves all containers for a specific deployment
func (c *Client) ListContainers(deploymentID string) (*ContainerList, error) {
if deploymentID == "" {
return nil, fmt.Errorf("deployment ID cannot be empty")
}
endpoint := fmt.Sprintf("/deployment/%s/containers", deploymentID)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to list containers: %w", err)
}
var containerList ContainerList
if err := decodeDataWithFlexibleTimes(resp.Body, &containerList); err != nil {
return nil, fmt.Errorf("failed to parse containers list: %w", err)
}
return &containerList, nil
}
// GetContainerDetails retrieves detailed information about a specific container
func (c *Client) GetContainerDetails(deploymentID, containerID string) (*Container, error) {
if deploymentID == "" {
return nil, fmt.Errorf("deployment ID cannot be empty")
}
if containerID == "" {
return nil, fmt.Errorf("container ID cannot be empty")
}
endpoint := fmt.Sprintf("/deployment/%s/container/%s", deploymentID, containerID)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get container details: %w", err)
}
// API response format not documented, assuming direct format
var container Container
if err := decodeWithFlexibleTimes(resp.Body, &container); err != nil {
return nil, fmt.Errorf("failed to parse container details: %w", err)
}
return &container, nil
}
// GetContainerJobs retrieves containers jobs for a specific container (similar to containers endpoint)
func (c *Client) GetContainerJobs(deploymentID, containerID string) (*ContainerList, error) {
if deploymentID == "" {
return nil, fmt.Errorf("deployment ID cannot be empty")
}
if containerID == "" {
return nil, fmt.Errorf("container ID cannot be empty")
}
endpoint := fmt.Sprintf("/deployment/%s/containers-jobs/%s", deploymentID, containerID)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get container jobs: %w", err)
}
var containerList ContainerList
if err := decodeDataWithFlexibleTimes(resp.Body, &containerList); err != nil {
return nil, fmt.Errorf("failed to parse container jobs: %w", err)
}
return &containerList, nil
}
// buildLogEndpoint constructs the request path for fetching logs
func buildLogEndpoint(deploymentID, containerID string, opts *GetLogsOptions) (string, error) {
if deploymentID == "" {
return "", fmt.Errorf("deployment ID cannot be empty")
}
if containerID == "" {
return "", fmt.Errorf("container ID cannot be empty")
}
params := make(map[string]interface{})
if opts != nil {
if opts.Level != "" {
params["level"] = opts.Level
}
if opts.Stream != "" {
params["stream"] = opts.Stream
}
if opts.Limit > 0 {
params["limit"] = opts.Limit
}
if opts.Cursor != "" {
params["cursor"] = opts.Cursor
}
if opts.Follow {
params["follow"] = true
}
if opts.StartTime != nil {
params["start_time"] = opts.StartTime
}
if opts.EndTime != nil {
params["end_time"] = opts.EndTime
}
}
endpoint := fmt.Sprintf("/deployment/%s/log/%s", deploymentID, containerID)
endpoint += buildQueryParams(params)
return endpoint, nil
}
// GetContainerLogs retrieves logs for containers in a deployment and normalizes them
func (c *Client) GetContainerLogs(deploymentID, containerID string, opts *GetLogsOptions) (*ContainerLogs, error) {
raw, err := c.GetContainerLogsRaw(deploymentID, containerID, opts)
if err != nil {
return nil, err
}
logs := &ContainerLogs{
ContainerID: containerID,
}
if raw == "" {
return logs, nil
}
normalized := strings.ReplaceAll(raw, "\r\n", "\n")
lines := strings.Split(normalized, "\n")
logs.Logs = lo.FilterMap(lines, func(line string, _ int) (LogEntry, bool) {
if strings.TrimSpace(line) == "" {
return LogEntry{}, false
}
return LogEntry{Message: line}, true
})
return logs, nil
}
// GetContainerLogsRaw retrieves the raw text logs for a specific container
func (c *Client) GetContainerLogsRaw(deploymentID, containerID string, opts *GetLogsOptions) (string, error) {
endpoint, err := buildLogEndpoint(deploymentID, containerID, opts)
if err != nil {
return "", err
}
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return "", fmt.Errorf("failed to get container logs: %w", err)
}
return string(resp.Body), nil
}
// StreamContainerLogs streams real-time logs for a specific container
// This method uses a callback function to handle incoming log entries
func (c *Client) StreamContainerLogs(deploymentID, containerID string, opts *GetLogsOptions, callback func(*LogEntry) error) error {
if deploymentID == "" {
return fmt.Errorf("deployment ID cannot be empty")
}
if containerID == "" {
return fmt.Errorf("container ID cannot be empty")
}
if callback == nil {
return fmt.Errorf("callback function cannot be nil")
}
// Set follow to true for streaming
if opts == nil {
opts = &GetLogsOptions{}
}
opts.Follow = true
endpoint, err := buildLogEndpoint(deploymentID, containerID, opts)
if err != nil {
return err
}
// Note: This is a simplified implementation. In a real scenario, you might want to use
// Server-Sent Events (SSE) or WebSocket for streaming logs
for {
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return fmt.Errorf("failed to stream container logs: %w", err)
}
var logs ContainerLogs
if err := decodeWithFlexibleTimes(resp.Body, &logs); err != nil {
return fmt.Errorf("failed to parse container logs: %w", err)
}
// Call the callback for each log entry
for _, logEntry := range logs.Logs {
if err := callback(&logEntry); err != nil {
return fmt.Errorf("callback error: %w", err)
}
}
// If there are no more logs or we have a cursor, continue polling
if !logs.HasMore && logs.NextCursor == "" {
break
}
// Update cursor for next request
if logs.NextCursor != "" {
opts.Cursor = logs.NextCursor
endpoint, err = buildLogEndpoint(deploymentID, containerID, opts)
if err != nil {
return err
}
}
// Wait a bit before next poll to avoid overwhelming the API
time.Sleep(2 * time.Second)
}
return nil
}
// RestartContainer restarts a specific container (if supported by the API)
func (c *Client) RestartContainer(deploymentID, containerID string) error {
if deploymentID == "" {
return fmt.Errorf("deployment ID cannot be empty")
}
if containerID == "" {
return fmt.Errorf("container ID cannot be empty")
}
endpoint := fmt.Sprintf("/deployment/%s/container/%s/restart", deploymentID, containerID)
_, err := c.makeRequest("POST", endpoint, nil)
if err != nil {
return fmt.Errorf("failed to restart container: %w", err)
}
return nil
}
// StopContainer stops a specific container (if supported by the API)
func (c *Client) StopContainer(deploymentID, containerID string) error {
if deploymentID == "" {
return fmt.Errorf("deployment ID cannot be empty")
}
if containerID == "" {
return fmt.Errorf("container ID cannot be empty")
}
endpoint := fmt.Sprintf("/deployment/%s/container/%s/stop", deploymentID, containerID)
_, err := c.makeRequest("POST", endpoint, nil)
if err != nil {
return fmt.Errorf("failed to stop container: %w", err)
}
return nil
}
// ExecuteInContainer executes a command in a specific container (if supported by the API)
func (c *Client) ExecuteInContainer(deploymentID, containerID string, command []string) (string, error) {
if deploymentID == "" {
return "", fmt.Errorf("deployment ID cannot be empty")
}
if containerID == "" {
return "", fmt.Errorf("container ID cannot be empty")
}
if len(command) == 0 {
return "", fmt.Errorf("command cannot be empty")
}
reqBody := map[string]interface{}{
"command": command,
}
endpoint := fmt.Sprintf("/deployment/%s/container/%s/exec", deploymentID, containerID)
resp, err := c.makeRequest("POST", endpoint, reqBody)
if err != nil {
return "", fmt.Errorf("failed to execute command in container: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(resp.Body, &result); err != nil {
return "", fmt.Errorf("failed to parse execution result: %w", err)
}
if output, ok := result["output"].(string); ok {
return output, nil
}
return string(resp.Body), nil
}

377
pkg/ionet/deployment.go Normal file
View File

@@ -0,0 +1,377 @@
package ionet
import (
"encoding/json"
"fmt"
"strings"
"github.com/samber/lo"
)
// DeployContainer deploys a new container with the specified configuration
func (c *Client) DeployContainer(req *DeploymentRequest) (*DeploymentResponse, error) {
if req == nil {
return nil, fmt.Errorf("deployment request cannot be nil")
}
// Validate required fields
if req.ResourcePrivateName == "" {
return nil, fmt.Errorf("resource_private_name is required")
}
if len(req.LocationIDs) == 0 {
return nil, fmt.Errorf("location_ids is required")
}
if req.HardwareID <= 0 {
return nil, fmt.Errorf("hardware_id is required")
}
if req.RegistryConfig.ImageURL == "" {
return nil, fmt.Errorf("registry_config.image_url is required")
}
if req.GPUsPerContainer < 1 {
return nil, fmt.Errorf("gpus_per_container must be at least 1")
}
if req.DurationHours < 1 {
return nil, fmt.Errorf("duration_hours must be at least 1")
}
if req.ContainerConfig.ReplicaCount < 1 {
return nil, fmt.Errorf("container_config.replica_count must be at least 1")
}
resp, err := c.makeRequest("POST", "/deploy", req)
if err != nil {
return nil, fmt.Errorf("failed to deploy container: %w", err)
}
// API returns direct format:
// {"status": "string", "deployment_id": "..."}
var deployResp DeploymentResponse
if err := json.Unmarshal(resp.Body, &deployResp); err != nil {
return nil, fmt.Errorf("failed to parse deployment response: %w", err)
}
return &deployResp, nil
}
// ListDeployments retrieves a list of deployments with optional filtering
func (c *Client) ListDeployments(opts *ListDeploymentsOptions) (*DeploymentList, error) {
params := make(map[string]interface{})
if opts != nil {
params["status"] = opts.Status
params["location_id"] = opts.LocationID
params["page"] = opts.Page
params["page_size"] = opts.PageSize
params["sort_by"] = opts.SortBy
params["sort_order"] = opts.SortOrder
}
endpoint := "/deployments" + buildQueryParams(params)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to list deployments: %w", err)
}
var deploymentList DeploymentList
if err := decodeData(resp.Body, &deploymentList); err != nil {
return nil, fmt.Errorf("failed to parse deployments list: %w", err)
}
deploymentList.Deployments = lo.Map(deploymentList.Deployments, func(deployment Deployment, _ int) Deployment {
deployment.GPUCount = deployment.HardwareQuantity
deployment.Replicas = deployment.HardwareQuantity // Assuming 1:1 mapping for now
return deployment
})
return &deploymentList, nil
}
// GetDeployment retrieves detailed information about a specific deployment
func (c *Client) GetDeployment(deploymentID string) (*DeploymentDetail, error) {
if deploymentID == "" {
return nil, fmt.Errorf("deployment ID cannot be empty")
}
endpoint := fmt.Sprintf("/deployment/%s", deploymentID)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get deployment details: %w", err)
}
var deploymentDetail DeploymentDetail
if err := decodeDataWithFlexibleTimes(resp.Body, &deploymentDetail); err != nil {
return nil, fmt.Errorf("failed to parse deployment details: %w", err)
}
return &deploymentDetail, nil
}
// UpdateDeployment updates the configuration of an existing deployment
func (c *Client) UpdateDeployment(deploymentID string, req *UpdateDeploymentRequest) (*UpdateDeploymentResponse, error) {
if deploymentID == "" {
return nil, fmt.Errorf("deployment ID cannot be empty")
}
if req == nil {
return nil, fmt.Errorf("update request cannot be nil")
}
endpoint := fmt.Sprintf("/deployment/%s", deploymentID)
resp, err := c.makeRequest("PATCH", endpoint, req)
if err != nil {
return nil, fmt.Errorf("failed to update deployment: %w", err)
}
// API returns direct format:
// {"status": "string", "deployment_id": "..."}
var updateResp UpdateDeploymentResponse
if err := json.Unmarshal(resp.Body, &updateResp); err != nil {
return nil, fmt.Errorf("failed to parse update deployment response: %w", err)
}
return &updateResp, nil
}
// ExtendDeployment extends the duration of an existing deployment
func (c *Client) ExtendDeployment(deploymentID string, req *ExtendDurationRequest) (*DeploymentDetail, error) {
if deploymentID == "" {
return nil, fmt.Errorf("deployment ID cannot be empty")
}
if req == nil {
return nil, fmt.Errorf("extend request cannot be nil")
}
if req.DurationHours < 1 {
return nil, fmt.Errorf("duration_hours must be at least 1")
}
endpoint := fmt.Sprintf("/deployment/%s/extend", deploymentID)
resp, err := c.makeRequest("POST", endpoint, req)
if err != nil {
return nil, fmt.Errorf("failed to extend deployment: %w", err)
}
var deploymentDetail DeploymentDetail
if err := decodeDataWithFlexibleTimes(resp.Body, &deploymentDetail); err != nil {
return nil, fmt.Errorf("failed to parse extended deployment details: %w", err)
}
return &deploymentDetail, nil
}
// DeleteDeployment deletes an active deployment
func (c *Client) DeleteDeployment(deploymentID string) (*UpdateDeploymentResponse, error) {
if deploymentID == "" {
return nil, fmt.Errorf("deployment ID cannot be empty")
}
endpoint := fmt.Sprintf("/deployment/%s", deploymentID)
resp, err := c.makeRequest("DELETE", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to delete deployment: %w", err)
}
// API returns direct format:
// {"status": "string", "deployment_id": "..."}
var deleteResp UpdateDeploymentResponse
if err := json.Unmarshal(resp.Body, &deleteResp); err != nil {
return nil, fmt.Errorf("failed to parse delete deployment response: %w", err)
}
return &deleteResp, nil
}
// GetPriceEstimation calculates the estimated cost for a deployment
func (c *Client) GetPriceEstimation(req *PriceEstimationRequest) (*PriceEstimationResponse, error) {
if req == nil {
return nil, fmt.Errorf("price estimation request cannot be nil")
}
// Validate required fields
if len(req.LocationIDs) == 0 {
return nil, fmt.Errorf("location_ids is required")
}
if req.HardwareID == 0 {
return nil, fmt.Errorf("hardware_id is required")
}
if req.ReplicaCount < 1 {
return nil, fmt.Errorf("replica_count must be at least 1")
}
currency := strings.TrimSpace(req.Currency)
if currency == "" {
currency = "usdc"
}
durationType := strings.TrimSpace(req.DurationType)
if durationType == "" {
durationType = "hour"
}
durationType = strings.ToLower(durationType)
apiDurationType := ""
durationQty := req.DurationQty
if durationQty < 1 {
durationQty = req.DurationHours
}
if durationQty < 1 {
return nil, fmt.Errorf("duration_qty must be at least 1")
}
hardwareQty := req.HardwareQty
if hardwareQty < 1 {
hardwareQty = req.GPUsPerContainer
}
if hardwareQty < 1 {
return nil, fmt.Errorf("hardware_qty must be at least 1")
}
durationHoursForRate := req.DurationHours
if durationHoursForRate < 1 {
durationHoursForRate = durationQty
}
switch durationType {
case "hour", "hours", "hourly":
durationHoursForRate = durationQty
apiDurationType = "hourly"
case "day", "days", "daily":
durationHoursForRate = durationQty * 24
apiDurationType = "daily"
case "week", "weeks", "weekly":
durationHoursForRate = durationQty * 24 * 7
apiDurationType = "weekly"
case "month", "months", "monthly":
durationHoursForRate = durationQty * 24 * 30
apiDurationType = "monthly"
}
if durationHoursForRate < 1 {
durationHoursForRate = 1
}
if apiDurationType == "" {
apiDurationType = "hourly"
}
params := map[string]interface{}{
"location_ids": req.LocationIDs,
"hardware_id": req.HardwareID,
"hardware_qty": hardwareQty,
"gpus_per_container": req.GPUsPerContainer,
"duration_type": apiDurationType,
"duration_qty": durationQty,
"duration_hours": req.DurationHours,
"replica_count": req.ReplicaCount,
"currency": currency,
}
endpoint := "/price" + buildQueryParams(params)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get price estimation: %w", err)
}
// Parse according to the actual API response format from docs:
// {
// "data": {
// "replica_count": 0,
// "gpus_per_container": 0,
// "available_replica_count": [0],
// "discount": 0,
// "ionet_fee": 0,
// "ionet_fee_percent": 0,
// "currency_conversion_fee": 0,
// "currency_conversion_fee_percent": 0,
// "total_cost_usdc": 0
// }
// }
var pricingData struct {
ReplicaCount int `json:"replica_count"`
GPUsPerContainer int `json:"gpus_per_container"`
AvailableReplicaCount []int `json:"available_replica_count"`
Discount float64 `json:"discount"`
IonetFee float64 `json:"ionet_fee"`
IonetFeePercent float64 `json:"ionet_fee_percent"`
CurrencyConversionFee float64 `json:"currency_conversion_fee"`
CurrencyConversionFeePercent float64 `json:"currency_conversion_fee_percent"`
TotalCostUSDC float64 `json:"total_cost_usdc"`
}
if err := decodeData(resp.Body, &pricingData); err != nil {
return nil, fmt.Errorf("failed to parse price estimation response: %w", err)
}
// Convert to our internal format
durationHoursFloat := float64(durationHoursForRate)
if durationHoursFloat <= 0 {
durationHoursFloat = 1
}
priceResp := &PriceEstimationResponse{
EstimatedCost: pricingData.TotalCostUSDC,
Currency: strings.ToUpper(currency),
EstimationValid: true,
PriceBreakdown: PriceBreakdown{
ComputeCost: pricingData.TotalCostUSDC - pricingData.IonetFee - pricingData.CurrencyConversionFee,
TotalCost: pricingData.TotalCostUSDC,
HourlyRate: pricingData.TotalCostUSDC / durationHoursFloat,
},
}
return priceResp, nil
}
// CheckClusterNameAvailability checks if a cluster name is available
func (c *Client) CheckClusterNameAvailability(clusterName string) (bool, error) {
if clusterName == "" {
return false, fmt.Errorf("cluster name cannot be empty")
}
params := map[string]interface{}{
"cluster_name": clusterName,
}
endpoint := "/clusters/check_cluster_name_availability" + buildQueryParams(params)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return false, fmt.Errorf("failed to check cluster name availability: %w", err)
}
var availabilityResp bool
if err := json.Unmarshal(resp.Body, &availabilityResp); err != nil {
return false, fmt.Errorf("failed to parse cluster name availability response: %w", err)
}
return availabilityResp, nil
}
// UpdateClusterName updates the name of an existing cluster/deployment
func (c *Client) UpdateClusterName(clusterID string, req *UpdateClusterNameRequest) (*UpdateClusterNameResponse, error) {
if clusterID == "" {
return nil, fmt.Errorf("cluster ID cannot be empty")
}
if req == nil {
return nil, fmt.Errorf("update cluster name request cannot be nil")
}
if req.Name == "" {
return nil, fmt.Errorf("cluster name cannot be empty")
}
endpoint := fmt.Sprintf("/clusters/%s/update-name", clusterID)
resp, err := c.makeRequest("PUT", endpoint, req)
if err != nil {
return nil, fmt.Errorf("failed to update cluster name: %w", err)
}
// Parse the response directly without data wrapper based on API docs
var updateResp UpdateClusterNameResponse
if err := json.Unmarshal(resp.Body, &updateResp); err != nil {
return nil, fmt.Errorf("failed to parse update cluster name response: %w", err)
}
return &updateResp, nil
}

202
pkg/ionet/hardware.go Normal file
View File

@@ -0,0 +1,202 @@
package ionet
import (
"encoding/json"
"fmt"
"strings"
"github.com/samber/lo"
)
// GetAvailableReplicas retrieves available replicas per location for specified hardware
func (c *Client) GetAvailableReplicas(hardwareID int, gpuCount int) (*AvailableReplicasResponse, error) {
if hardwareID <= 0 {
return nil, fmt.Errorf("hardware_id must be greater than 0")
}
if gpuCount < 1 {
return nil, fmt.Errorf("gpu_count must be at least 1")
}
params := map[string]interface{}{
"hardware_id": hardwareID,
"hardware_qty": gpuCount,
}
endpoint := "/available-replicas" + buildQueryParams(params)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get available replicas: %w", err)
}
type availableReplicaPayload struct {
ID int `json:"id"`
ISO2 string `json:"iso2"`
Name string `json:"name"`
AvailableReplicas int `json:"available_replicas"`
}
var payload []availableReplicaPayload
if err := decodeData(resp.Body, &payload); err != nil {
return nil, fmt.Errorf("failed to parse available replicas response: %w", err)
}
replicas := lo.Map(payload, func(item availableReplicaPayload, _ int) AvailableReplica {
return AvailableReplica{
LocationID: item.ID,
LocationName: item.Name,
HardwareID: hardwareID,
HardwareName: "",
AvailableCount: item.AvailableReplicas,
MaxGPUs: gpuCount,
}
})
return &AvailableReplicasResponse{Replicas: replicas}, nil
}
// GetMaxGPUsPerContainer retrieves the maximum number of GPUs available per hardware type
func (c *Client) GetMaxGPUsPerContainer() (*MaxGPUResponse, error) {
resp, err := c.makeRequest("GET", "/hardware/max-gpus-per-container", nil)
if err != nil {
return nil, fmt.Errorf("failed to get max GPUs per container: %w", err)
}
var maxGPUResp MaxGPUResponse
if err := decodeData(resp.Body, &maxGPUResp); err != nil {
return nil, fmt.Errorf("failed to parse max GPU response: %w", err)
}
return &maxGPUResp, nil
}
// ListHardwareTypes retrieves available hardware types using the max GPUs endpoint
func (c *Client) ListHardwareTypes() ([]HardwareType, int, error) {
maxGPUResp, err := c.GetMaxGPUsPerContainer()
if err != nil {
return nil, 0, fmt.Errorf("failed to list hardware types: %w", err)
}
mapped := lo.Map(maxGPUResp.Hardware, func(hw MaxGPUInfo, _ int) HardwareType {
name := strings.TrimSpace(hw.HardwareName)
if name == "" {
name = fmt.Sprintf("Hardware %d", hw.HardwareID)
}
return HardwareType{
ID: hw.HardwareID,
Name: name,
GPUType: "",
GPUMemory: 0,
MaxGPUs: hw.MaxGPUsPerContainer,
CPU: "",
Memory: 0,
Storage: 0,
HourlyRate: 0,
Available: hw.Available > 0,
BrandName: strings.TrimSpace(hw.BrandName),
AvailableCount: hw.Available,
}
})
totalAvailable := maxGPUResp.Total
if totalAvailable == 0 {
totalAvailable = lo.SumBy(maxGPUResp.Hardware, func(hw MaxGPUInfo) int {
return hw.Available
})
}
return mapped, totalAvailable, nil
}
// ListLocations retrieves available deployment locations (if supported by the API)
func (c *Client) ListLocations() (*LocationsResponse, error) {
resp, err := c.makeRequest("GET", "/locations", nil)
if err != nil {
return nil, fmt.Errorf("failed to list locations: %w", err)
}
var locations LocationsResponse
if err := decodeData(resp.Body, &locations); err != nil {
return nil, fmt.Errorf("failed to parse locations response: %w", err)
}
locations.Locations = lo.Map(locations.Locations, func(location Location, _ int) Location {
location.ISO2 = strings.ToUpper(strings.TrimSpace(location.ISO2))
return location
})
if locations.Total == 0 {
locations.Total = lo.SumBy(locations.Locations, func(location Location) int {
return location.Available
})
}
return &locations, nil
}
// GetHardwareType retrieves details about a specific hardware type
func (c *Client) GetHardwareType(hardwareID int) (*HardwareType, error) {
if hardwareID <= 0 {
return nil, fmt.Errorf("hardware ID must be greater than 0")
}
endpoint := fmt.Sprintf("/hardware/types/%d", hardwareID)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get hardware type: %w", err)
}
// API response format not documented, assuming direct format
var hardwareType HardwareType
if err := json.Unmarshal(resp.Body, &hardwareType); err != nil {
return nil, fmt.Errorf("failed to parse hardware type: %w", err)
}
return &hardwareType, nil
}
// GetLocation retrieves details about a specific location
func (c *Client) GetLocation(locationID int) (*Location, error) {
if locationID <= 0 {
return nil, fmt.Errorf("location ID must be greater than 0")
}
endpoint := fmt.Sprintf("/locations/%d", locationID)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get location: %w", err)
}
// API response format not documented, assuming direct format
var location Location
if err := json.Unmarshal(resp.Body, &location); err != nil {
return nil, fmt.Errorf("failed to parse location: %w", err)
}
return &location, nil
}
// GetLocationAvailability retrieves real-time availability for a specific location
func (c *Client) GetLocationAvailability(locationID int) (*LocationAvailability, error) {
if locationID <= 0 {
return nil, fmt.Errorf("location ID must be greater than 0")
}
endpoint := fmt.Sprintf("/locations/%d/availability", locationID)
resp, err := c.makeRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to get location availability: %w", err)
}
// API response format not documented, assuming direct format
var availability LocationAvailability
if err := json.Unmarshal(resp.Body, &availability); err != nil {
return nil, fmt.Errorf("failed to parse location availability: %w", err)
}
return &availability, nil
}

96
pkg/ionet/jsonutil.go Normal file
View File

@@ -0,0 +1,96 @@
package ionet
import (
"encoding/json"
"strings"
"time"
"github.com/samber/lo"
)
// decodeWithFlexibleTimes unmarshals API responses while tolerating timestamp strings
// that omit timezone information by normalizing them to RFC3339Nano.
func decodeWithFlexibleTimes(data []byte, target interface{}) error {
var intermediate interface{}
if err := json.Unmarshal(data, &intermediate); err != nil {
return err
}
normalized := normalizeTimeValues(intermediate)
reencoded, err := json.Marshal(normalized)
if err != nil {
return err
}
return json.Unmarshal(reencoded, target)
}
func decodeData[T any](data []byte, target *T) error {
var wrapper struct {
Data T `json:"data"`
}
if err := json.Unmarshal(data, &wrapper); err != nil {
return err
}
*target = wrapper.Data
return nil
}
func decodeDataWithFlexibleTimes[T any](data []byte, target *T) error {
var wrapper struct {
Data T `json:"data"`
}
if err := decodeWithFlexibleTimes(data, &wrapper); err != nil {
return err
}
*target = wrapper.Data
return nil
}
func normalizeTimeValues(value interface{}) interface{} {
switch v := value.(type) {
case map[string]interface{}:
return lo.MapValues(v, func(val interface{}, _ string) interface{} {
return normalizeTimeValues(val)
})
case []interface{}:
return lo.Map(v, func(item interface{}, _ int) interface{} {
return normalizeTimeValues(item)
})
case string:
if normalized, changed := normalizeTimeString(v); changed {
return normalized
}
return v
default:
return value
}
}
func normalizeTimeString(input string) (string, bool) {
trimmed := strings.TrimSpace(input)
if trimmed == "" {
return input, false
}
if _, err := time.Parse(time.RFC3339Nano, trimmed); err == nil {
return trimmed, trimmed != input
}
if _, err := time.Parse(time.RFC3339, trimmed); err == nil {
return trimmed, trimmed != input
}
layouts := []string{
"2006-01-02T15:04:05.999999999",
"2006-01-02T15:04:05.999999",
"2006-01-02T15:04:05",
}
for _, layout := range layouts {
if parsed, err := time.Parse(layout, trimmed); err == nil {
return parsed.UTC().Format(time.RFC3339Nano), true
}
}
return input, false
}

353
pkg/ionet/types.go Normal file
View File

@@ -0,0 +1,353 @@
package ionet
import (
"time"
)
// Client represents the IO.NET API client
type Client struct {
BaseURL string
APIKey string
HTTPClient HTTPClient
}
// HTTPClient interface for making HTTP requests
type HTTPClient interface {
Do(req *HTTPRequest) (*HTTPResponse, error)
}
// HTTPRequest represents an HTTP request
type HTTPRequest struct {
Method string
URL string
Headers map[string]string
Body []byte
}
// HTTPResponse represents an HTTP response
type HTTPResponse struct {
StatusCode int
Headers map[string]string
Body []byte
}
// DeploymentRequest represents a container deployment request
type DeploymentRequest struct {
ResourcePrivateName string `json:"resource_private_name"`
DurationHours int `json:"duration_hours"`
GPUsPerContainer int `json:"gpus_per_container"`
HardwareID int `json:"hardware_id"`
LocationIDs []int `json:"location_ids"`
ContainerConfig ContainerConfig `json:"container_config"`
RegistryConfig RegistryConfig `json:"registry_config"`
}
// ContainerConfig represents container configuration
type ContainerConfig struct {
ReplicaCount int `json:"replica_count"`
EnvVariables map[string]string `json:"env_variables,omitempty"`
SecretEnvVariables map[string]string `json:"secret_env_variables,omitempty"`
Entrypoint []string `json:"entrypoint,omitempty"`
TrafficPort int `json:"traffic_port,omitempty"`
Args []string `json:"args,omitempty"`
}
// RegistryConfig represents registry configuration
type RegistryConfig struct {
ImageURL string `json:"image_url"`
RegistryUsername string `json:"registry_username,omitempty"`
RegistrySecret string `json:"registry_secret,omitempty"`
}
// DeploymentResponse represents the response from deployment creation
type DeploymentResponse struct {
DeploymentID string `json:"deployment_id"`
Status string `json:"status"`
}
// DeploymentDetail represents detailed deployment information
type DeploymentDetail struct {
ID string `json:"id"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
FinishedAt *time.Time `json:"finished_at,omitempty"`
AmountPaid float64 `json:"amount_paid"`
CompletedPercent float64 `json:"completed_percent"`
TotalGPUs int `json:"total_gpus"`
GPUsPerContainer int `json:"gpus_per_container"`
TotalContainers int `json:"total_containers"`
HardwareName string `json:"hardware_name"`
HardwareID int `json:"hardware_id"`
Locations []DeploymentLocation `json:"locations"`
BrandName string `json:"brand_name"`
ComputeMinutesServed int `json:"compute_minutes_served"`
ComputeMinutesRemaining int `json:"compute_minutes_remaining"`
ContainerConfig DeploymentContainerConfig `json:"container_config"`
}
// DeploymentLocation represents a location in deployment details
type DeploymentLocation struct {
ID int `json:"id"`
ISO2 string `json:"iso2"`
Name string `json:"name"`
}
// DeploymentContainerConfig represents container config in deployment details
type DeploymentContainerConfig struct {
Entrypoint []string `json:"entrypoint"`
EnvVariables map[string]interface{} `json:"env_variables"`
TrafficPort int `json:"traffic_port"`
ImageURL string `json:"image_url"`
}
// Container represents a container within a deployment
type Container struct {
DeviceID string `json:"device_id"`
ContainerID string `json:"container_id"`
Hardware string `json:"hardware"`
BrandName string `json:"brand_name"`
CreatedAt time.Time `json:"created_at"`
UptimePercent int `json:"uptime_percent"`
GPUsPerContainer int `json:"gpus_per_container"`
Status string `json:"status"`
ContainerEvents []ContainerEvent `json:"container_events"`
PublicURL string `json:"public_url"`
}
// ContainerEvent represents a container event
type ContainerEvent struct {
Time time.Time `json:"time"`
Message string `json:"message"`
}
// ContainerList represents a list of containers
type ContainerList struct {
Total int `json:"total"`
Workers []Container `json:"workers"`
}
// Deployment represents a deployment in the list
type Deployment struct {
ID string `json:"id"`
Status string `json:"status"`
Name string `json:"name"`
CompletedPercent float64 `json:"completed_percent"`
HardwareQuantity int `json:"hardware_quantity"`
BrandName string `json:"brand_name"`
HardwareName string `json:"hardware_name"`
Served string `json:"served"`
Remaining string `json:"remaining"`
ComputeMinutesServed int `json:"compute_minutes_served"`
ComputeMinutesRemaining int `json:"compute_minutes_remaining"`
CreatedAt time.Time `json:"created_at"`
GPUCount int `json:"-"` // Derived from HardwareQuantity
Replicas int `json:"-"` // Derived from HardwareQuantity
}
// DeploymentList represents a list of deployments with pagination
type DeploymentList struct {
Deployments []Deployment `json:"deployments"`
Total int `json:"total"`
Statuses []string `json:"statuses"`
}
// AvailableReplica represents replica availability for a location
type AvailableReplica struct {
LocationID int `json:"location_id"`
LocationName string `json:"location_name"`
HardwareID int `json:"hardware_id"`
HardwareName string `json:"hardware_name"`
AvailableCount int `json:"available_count"`
MaxGPUs int `json:"max_gpus"`
}
// AvailableReplicasResponse represents the response for available replicas
type AvailableReplicasResponse struct {
Replicas []AvailableReplica `json:"replicas"`
}
// MaxGPUResponse represents the response for maximum GPUs per container
type MaxGPUResponse struct {
Hardware []MaxGPUInfo `json:"hardware"`
Total int `json:"total"`
}
// MaxGPUInfo represents max GPU information for a hardware type
type MaxGPUInfo struct {
MaxGPUsPerContainer int `json:"max_gpus_per_container"`
Available int `json:"available"`
HardwareID int `json:"hardware_id"`
HardwareName string `json:"hardware_name"`
BrandName string `json:"brand_name"`
}
// PriceEstimationRequest represents a price estimation request
type PriceEstimationRequest struct {
LocationIDs []int `json:"location_ids"`
HardwareID int `json:"hardware_id"`
GPUsPerContainer int `json:"gpus_per_container"`
DurationHours int `json:"duration_hours"`
ReplicaCount int `json:"replica_count"`
Currency string `json:"currency"`
DurationType string `json:"duration_type"`
DurationQty int `json:"duration_qty"`
HardwareQty int `json:"hardware_qty"`
}
// PriceEstimationResponse represents the price estimation response
type PriceEstimationResponse struct {
EstimatedCost float64 `json:"estimated_cost"`
Currency string `json:"currency"`
PriceBreakdown PriceBreakdown `json:"price_breakdown"`
EstimationValid bool `json:"estimation_valid"`
}
// PriceBreakdown represents detailed cost breakdown
type PriceBreakdown struct {
ComputeCost float64 `json:"compute_cost"`
NetworkCost float64 `json:"network_cost,omitempty"`
StorageCost float64 `json:"storage_cost,omitempty"`
TotalCost float64 `json:"total_cost"`
HourlyRate float64 `json:"hourly_rate"`
}
// ContainerLogs represents container log entries
type ContainerLogs struct {
ContainerID string `json:"container_id"`
Logs []LogEntry `json:"logs"`
HasMore bool `json:"has_more"`
NextCursor string `json:"next_cursor,omitempty"`
}
// LogEntry represents a single log entry
type LogEntry struct {
Timestamp time.Time `json:"timestamp"`
Level string `json:"level,omitempty"`
Message string `json:"message"`
Source string `json:"source,omitempty"`
}
// UpdateDeploymentRequest represents request to update deployment configuration
type UpdateDeploymentRequest struct {
EnvVariables map[string]string `json:"env_variables,omitempty"`
SecretEnvVariables map[string]string `json:"secret_env_variables,omitempty"`
Entrypoint []string `json:"entrypoint,omitempty"`
TrafficPort *int `json:"traffic_port,omitempty"`
ImageURL string `json:"image_url,omitempty"`
RegistryUsername string `json:"registry_username,omitempty"`
RegistrySecret string `json:"registry_secret,omitempty"`
Args []string `json:"args,omitempty"`
Command string `json:"command,omitempty"`
}
// ExtendDurationRequest represents request to extend deployment duration
type ExtendDurationRequest struct {
DurationHours int `json:"duration_hours"`
}
// UpdateDeploymentResponse represents response from deployment update
type UpdateDeploymentResponse struct {
Status string `json:"status"`
DeploymentID string `json:"deployment_id"`
}
// UpdateClusterNameRequest represents request to update cluster name
type UpdateClusterNameRequest struct {
Name string `json:"cluster_name"`
}
// UpdateClusterNameResponse represents response from cluster name update
type UpdateClusterNameResponse struct {
Status string `json:"status"`
Message string `json:"message"`
}
// APIError represents an API error response
type APIError struct {
Code int `json:"code"`
Message string `json:"message"`
Details string `json:"details,omitempty"`
}
// Error implements the error interface
func (e *APIError) Error() string {
if e.Details != "" {
return e.Message + ": " + e.Details
}
return e.Message
}
// ListDeploymentsOptions represents options for listing deployments
type ListDeploymentsOptions struct {
Status string `json:"status,omitempty"` // filter by status
LocationID int `json:"location_id,omitempty"` // filter by location
Page int `json:"page,omitempty"` // pagination
PageSize int `json:"page_size,omitempty"` // pagination
SortBy string `json:"sort_by,omitempty"` // sort field
SortOrder string `json:"sort_order,omitempty"` // asc/desc
}
// GetLogsOptions represents options for retrieving container logs
type GetLogsOptions struct {
StartTime *time.Time `json:"start_time,omitempty"`
EndTime *time.Time `json:"end_time,omitempty"`
Level string `json:"level,omitempty"` // filter by log level
Stream string `json:"stream,omitempty"` // filter by stdout/stderr streams
Limit int `json:"limit,omitempty"` // max number of log entries
Cursor string `json:"cursor,omitempty"` // pagination cursor
Follow bool `json:"follow,omitempty"` // stream logs
}
// HardwareType represents a hardware type available for deployment
type HardwareType struct {
ID int `json:"id"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
GPUType string `json:"gpu_type"`
GPUMemory int `json:"gpu_memory"` // in GB
MaxGPUs int `json:"max_gpus"`
CPU string `json:"cpu,omitempty"`
Memory int `json:"memory,omitempty"` // in GB
Storage int `json:"storage,omitempty"` // in GB
HourlyRate float64 `json:"hourly_rate"`
Available bool `json:"available"`
BrandName string `json:"brand_name,omitempty"`
AvailableCount int `json:"available_count,omitempty"`
}
// Location represents a deployment location
type Location struct {
ID int `json:"id"`
Name string `json:"name"`
ISO2 string `json:"iso2,omitempty"`
Region string `json:"region,omitempty"`
Country string `json:"country,omitempty"`
Latitude float64 `json:"latitude,omitempty"`
Longitude float64 `json:"longitude,omitempty"`
Available int `json:"available,omitempty"`
Description string `json:"description,omitempty"`
}
// LocationsResponse represents the list of locations and aggregated metadata.
type LocationsResponse struct {
Locations []Location `json:"locations"`
Total int `json:"total"`
}
// LocationAvailability represents real-time availability for a location
type LocationAvailability struct {
LocationID int `json:"location_id"`
LocationName string `json:"location_name"`
Available bool `json:"available"`
HardwareAvailability []HardwareAvailability `json:"hardware_availability"`
UpdatedAt time.Time `json:"updated_at"`
}
// HardwareAvailability represents availability for specific hardware at a location
type HardwareAvailability struct {
HardwareID int `json:"hardware_id"`
HardwareName string `json:"hardware_name"`
AvailableCount int `json:"available_count"`
MaxGPUs int `json:"max_gpus"`
}

View File

@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, info, usage.(*dto.Usage), "")
postConsumeQuota(c, info, usage.(*dto.Usage))
}
return nil

View File

@@ -19,6 +19,22 @@ import (
)
type Adaptor struct {
IsSyncImageModel bool
}
var syncModels = []string{
"z-image",
"qwen-image",
"wan2.6",
}
func isSyncImageModel(modelName string) bool {
for _, m := range syncModels {
if strings.Contains(modelName, m) {
return true
}
}
return false
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -45,10 +61,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
if isSyncImageModel(info.OriginModelName) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
} else {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
}
case constant.RelayModeImagesEdits:
if isWanModel(info.OriginModelName) {
if isOldWanModel(info.OriginModelName) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
} else if isWanModel(info.OriginModelName) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl)
} else {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
}
@@ -72,7 +94,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
if info.RelayMode == constant.RelayModeImagesGenerations {
req.Set("X-DashScope-Async", "enable")
if isSyncImageModel(info.OriginModelName) {
} else {
req.Set("X-DashScope-Async", "enable")
}
}
if info.RelayMode == constant.RelayModeImagesEdits {
if isWanModel(info.OriginModelName) {
@@ -108,15 +134,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
if info.RelayMode == constant.RelayModeImagesGenerations {
aliRequest, err := oaiImage2Ali(request)
if isSyncImageModel(info.OriginModelName) {
a.IsSyncImageModel = true
}
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
if err != nil {
return nil, fmt.Errorf("convert image request failed: %w", err)
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
}
return aliRequest, nil
} else if info.RelayMode == constant.RelayModeImagesEdits {
if isWanModel(info.OriginModelName) {
if isOldWanModel(info.OriginModelName) {
return oaiFormEdit2WanxImageEdit(c, info, request)
}
if isSyncImageModel(info.OriginModelName) {
if isWanModel(info.OriginModelName) {
a.IsSyncImageModel = false
} else {
a.IsSyncImageModel = true
}
}
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
// 如果用户使用表单,则需要解析表单数据
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
@@ -126,9 +162,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
return aliRequest, nil
} else {
aliRequest, err := oaiImage2Ali(request)
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
if err != nil {
return nil, fmt.Errorf("convert image request failed: %w", err)
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
}
return aliRequest, nil
}
@@ -150,7 +186,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
//TODO implement me
return nil, errors.New("not implemented")
}
@@ -169,13 +205,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
default:
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
err, usage = aliImageHandler(a, c, resp, info)
case constant.RelayModeImagesEdits:
if isWanModel(info.OriginModelName) {
err, usage = aliImageHandler(c, resp, info)
} else {
err, usage = aliImageEditHandler(c, resp, info)
}
err, usage = aliImageHandler(a, c, resp, info)
case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info)
default:

View File

@@ -1,6 +1,13 @@
package ali
import "github.com/QuantumNous/new-api/dto"
import (
"strings"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
type AliMessage struct {
Content any `json:"content"`
@@ -65,6 +72,7 @@ type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
ImageCount int `json:"image_count,omitempty"`
}
type TaskResult struct {
@@ -75,14 +83,78 @@ type TaskResult struct {
}
type AliOutput struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"`
Choices []map[string]any `json:"choices,omitempty"`
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"`
Choices []struct {
FinishReason string `json:"finish_reason,omitempty"`
Message struct {
Role string `json:"role,omitempty"`
Content []AliMediaContent `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
} `json:"message,omitempty"`
} `json:"choices,omitempty"`
}
func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
var imageData []dto.ImageData
if len(o.Choices) > 0 {
for _, choice := range o.Choices {
var data dto.ImageData
for _, content := range choice.Message.Content {
if content.Image != "" {
if strings.HasPrefix(content.Image, "http") {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(content.Image)
if err != nil {
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
}
data.Url = content.Image
data.B64Json = b64Json
} else {
data.B64Json = content.Image
}
} else if content.Text != "" {
data.RevisedPrompt = content.Text
}
}
imageData = append(imageData, data)
}
}
return imageData
}
func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
var imageData []dto.ImageData
for _, data := range o.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageData = append(imageData, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return imageData
}
type AliResponse struct {
@@ -92,18 +164,26 @@ type AliResponse struct {
}
type AliImageRequest struct {
Model string `json:"model"`
Input any `json:"input"`
Parameters any `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Model string `json:"model"`
Input any `json:"input"`
Parameters AliImageParameters `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type AliImageParameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
Watermark *bool `json:"watermark,omitempty"`
PromptExtend *bool `json:"prompt_extend,omitempty"`
}
func (p *AliImageParameters) PromptExtendValue() bool {
if p != nil && p.PromptExtend != nil {
return *p.PromptExtend
}
return false
}
type AliImageInput struct {

View File

@@ -1,7 +1,6 @@
package ali
import (
"context"
"encoding/base64"
"errors"
"fmt"
@@ -21,17 +20,23 @@ import (
"github.com/gin-gonic/gin"
)
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
var imageRequest AliImageRequest
imageRequest.Model = request.Model
imageRequest.ResponseFormat = request.ResponseFormat
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
if request.Extra != nil {
if val, ok := request.Extra["parameters"]; ok {
err := common.Unmarshal(val, &imageRequest.Parameters)
if err != nil {
return nil, fmt.Errorf("invalid parameters field: %w", err)
}
} else {
// 兼容没有parameters字段的情况从openai标准字段中提取参数
imageRequest.Parameters = AliImageParameters{
Size: strings.Replace(request.Size, "x", "*", -1),
N: int(request.N),
Watermark: request.Watermark,
}
}
if val, ok := request.Extra["input"]; ok {
err := common.Unmarshal(val, &imageRequest.Input)
@@ -41,23 +46,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
}
}
if imageRequest.Parameters == nil {
imageRequest.Parameters = AliImageParameters{
Size: strings.Replace(request.Size, "x", "*", -1),
N: int(request.N),
Watermark: request.Watermark,
if strings.Contains(request.Model, "z-image") {
// z-image 开启prompt_extend后按2倍计费
if imageRequest.Parameters.PromptExtendValue() {
info.PriceData.AddOtherRatio("prompt_extend", 2)
}
}
if imageRequest.Input == nil {
imageRequest.Input = AliImageInput{
Prompt: request.Prompt,
// 检查n参数
if imageRequest.Parameters.N != 0 {
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
}
// 同步图片模型和异步图片模型请求格式不一样
if isSync {
if imageRequest.Input == nil {
imageRequest.Input = AliImageInput{
Messages: []AliMessage{
{
Role: "user",
Content: []AliMediaContent{
{
Text: request.Prompt,
},
},
},
},
}
}
} else {
if imageRequest.Input == nil {
imageRequest.Input = AliImageInput{
Prompt: request.Prompt,
}
}
}
return &imageRequest, nil
}
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
mf := c.Request.MultipartForm
if mf == nil {
@@ -199,6 +225,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
var taskResponse AliResponse
var responseBody []byte
time.Sleep(time.Duration(5) * time.Second)
for {
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
step++
@@ -238,32 +266,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
Created: info.StartTime.Unix(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
if len(response.Output.Results) > 0 {
imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
} else if len(response.Output.Choices) > 0 {
imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
}
var mapResponse map[string]any
_ = common.Unmarshal(originBody, &mapResponse)
imageResponse.Extra = mapResponse
imageResponse.Metadata = originBody
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
@@ -282,66 +295,49 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
}
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponse), nil
}
var (
aliResponse *AliResponse
originRespBody []byte
)
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return nil, &dto.Usage{}
}
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
service.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Message != "" {
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
}
var fullTextResponse dto.ImageResponse
if len(aliResponse.Output.Choices) > 0 {
fullTextResponse = dto.ImageResponse{
Created: info.StartTime.Unix(),
Data: []dto.ImageData{
{
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
B64Json: "",
},
},
if a.IsSyncImageModel {
aliResponse = &aliTaskResponse
originRespBody = responseBody
} else {
// 异步图片模型需要轮询任务结果
aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponse), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
}, resp.StatusCode), nil
}
}
var mapResponse map[string]any
_ = common.Unmarshal(responseBody, &mapResponse)
fullTextResponse.Extra = mapResponse
jsonResponse, err := common.Marshal(fullTextResponse)
//logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
if a.IsSyncImageModel {
logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
} else {
logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
}
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
// 可能生成多张图片修正计费数量n
if aliResponse.Usage.ImageCount != 0 {
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
} else if len(imageResponses.Data) != 0 {
info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
}
jsonResponse, err := common.Marshal(imageResponses)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
service.IOCopyBytesGracefully(c, resp, jsonResponse)
return nil, &dto.Usage{}
}

View File

@@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
}
wanParams := WanImageParameters{
//wanParams := WanImageParameters{
// N: int(request.N),
//}
imageRequest.Input = wanInput
imageRequest.Parameters = AliImageParameters{
N: int(request.N),
}
imageRequest.Input = wanInput
imageRequest.Parameters = wanParams
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
return &imageRequest, nil
}
func isOldWanModel(modelName string) bool {
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
}
func isWanModel(modelName string) bool {
return strings.Contains(modelName, "wan")
}

View File

@@ -1,11 +1,13 @@
package aws
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
@@ -37,6 +39,13 @@ func getAwsErrorStatusCode(err error) int {
return http.StatusInternalServerError
}
func newAwsInvokeContext() (context.Context, context.CancelFunc) {
if common.RelayTimeout <= 0 {
return context.Background(), func() {}
}
return context.WithTimeout(context.Background(), time.Duration(common.RelayTimeout)*time.Second)
}
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
var (
httpClient *http.Client
@@ -117,6 +126,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
return nil, types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody)
}
awsReq.Body = reqBody
a.AwsReq = awsReq
return nil, nil
} else {
awsClaudeReq, err := formatRequest(requestBody, requestHeader)
@@ -201,7 +211,10 @@ func getAwsModelID(requestModel string) string {
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
ctx, cancel := newAwsInvokeContext()
defer cancel()
awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
if err != nil {
statusCode := getAwsErrorStatusCode(err)
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil
@@ -228,7 +241,10 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
}
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
ctx, cancel := newAwsInvokeContext()
defer cancel()
awsResp, err := a.AwsClient.InvokeModelWithResponseStream(ctx, a.AwsReq.(*bedrockruntime.InvokeModelWithResponseStreamInput))
if err != nil {
statusCode := getAwsErrorStatusCode(err)
return types.NewOpenAIError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeAwsInvokeError, statusCode), nil
@@ -268,7 +284,10 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
// Nova模型处理函数
func handleNovaRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types.NewAPIError, *dto.Usage) {
awsResp, err := a.AwsClient.InvokeModel(c.Request.Context(), a.AwsReq.(*bedrockruntime.InvokeModelInput))
ctx, cancel := newAwsInvokeContext()
defer cancel()
awsResp, err := a.AwsClient.InvokeModel(ctx, a.AwsReq.(*bedrockruntime.InvokeModelInput))
if err != nil {
statusCode := getAwsErrorStatusCode(err)
return types.NewOpenAIError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeAwsInvokeError, statusCode), nil

View File

@@ -483,9 +483,11 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
}
}
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {

View File

@@ -13,6 +13,7 @@ import (
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/constant"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -137,7 +138,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
} else if baseModel, level := parseThinkingLevelSuffix(info.UpstreamModelName); level != "" {
} else if baseModel, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
info.UpstreamModelName = baseModel
}
}

View File

@@ -32,6 +32,7 @@ var geminiSupportedMimeTypes = map[string]bool{
"audio/wav": true,
"image/png": true,
"image/jpeg": true,
"image/jpg": true, // support old image/jpeg
"image/webp": true,
"text/plain": true,
"video/mov": true,
@@ -98,6 +99,7 @@ func clampThinkingBudget(modelName string, budget int) int {
// "effort": "high" - Allocates a large portion of tokens for reasoning (approximately 80% of max_tokens)
// "effort": "medium" - Allocates a moderate portion of tokens (approximately 50% of max_tokens)
// "effort": "low" - Allocates a smaller portion of tokens (approximately 20% of max_tokens)
// "effort": "minimal" - Allocates a minimal portion of tokens (approximately 5% of max_tokens)
func clampThinkingBudgetByEffort(modelName string, effort string) int {
isNew25Pro := isNew25ProModel(modelName)
is25FlashLite := is25FlashLiteModel(modelName)
@@ -118,18 +120,12 @@ func clampThinkingBudgetByEffort(modelName string, effort string) int {
maxBudget = maxBudget * 50 / 100
case "low":
maxBudget = maxBudget * 20 / 100
case "minimal":
maxBudget = maxBudget * 5 / 100
}
return clampThinkingBudget(modelName, maxBudget)
}
func parseThinkingLevelSuffix(modelName string) (string, string) {
base, level, ok := reasoning.TrimEffortSuffix(modelName)
if !ok {
return modelName, ""
}
return base, level
}
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo, oaiRequest ...dto.GeneralOpenAIRequest) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.UpstreamModelName
@@ -186,7 +182,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
ThinkingBudget: common.GetPointer(0),
}
}
} else if _, level := parseThinkingLevelSuffix(modelName); level != "" {
} else if _, level, ok := reasoning.TrimEffortSuffix(info.UpstreamModelName); ok && level != "" {
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
IncludeThoughts: true,
ThinkingLevel: level,
@@ -379,7 +375,7 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
var system_content []string
//shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
if message.Role == "system" {
if message.Role == "system" || message.Role == "developer" {
system_content = append(system_content, message.StringContent())
continue
} else if message.Role == "tool" || message.Role == "function" {

View File

@@ -67,3 +67,40 @@ type OllamaEmbeddingResponse struct {
Embeddings [][]float64 `json:"embeddings"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
}
type OllamaTagsResponse struct {
Models []OllamaModel `json:"models"`
}
type OllamaModel struct {
Name string `json:"name"`
Size int64 `json:"size"`
Digest string `json:"digest,omitempty"`
ModifiedAt string `json:"modified_at"`
Details OllamaModelDetail `json:"details,omitempty"`
}
type OllamaModelDetail struct {
ParentModel string `json:"parent_model,omitempty"`
Format string `json:"format,omitempty"`
Family string `json:"family,omitempty"`
Families []string `json:"families,omitempty"`
ParameterSize string `json:"parameter_size,omitempty"`
QuantizationLevel string `json:"quantization_level,omitempty"`
}
type OllamaPullRequest struct {
Name string `json:"name"`
Stream bool `json:"stream,omitempty"`
}
type OllamaPullResponse struct {
Status string `json:"status"`
Digest string `json:"digest,omitempty"`
Total int64 `json:"total,omitempty"`
Completed int64 `json:"completed,omitempty"`
}
type OllamaDeleteRequest struct {
Name string `json:"name"`
}

View File

@@ -1,11 +1,13 @@
package ollama
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
@@ -283,3 +285,246 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
service.IOCopyBytesGracefully(c, resp, out)
return usage, nil
}
func FetchOllamaModels(baseURL, apiKey string) ([]OllamaModel, error) {
url := fmt.Sprintf("%s/api/tags", baseURL)
client := &http.Client{}
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %v", err)
}
// Ollama 通常不需要 Bearer token但为了兼容性保留
if apiKey != "" {
request.Header.Set("Authorization", "Bearer "+apiKey)
}
response, err := client.Do(request)
if err != nil {
return nil, fmt.Errorf("请求失败: %v", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, fmt.Errorf("服务器返回错误 %d: %s", response.StatusCode, string(body))
}
var tagsResponse OllamaTagsResponse
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %v", err)
}
err = common.Unmarshal(body, &tagsResponse)
if err != nil {
return nil, fmt.Errorf("解析响应失败: %v", err)
}
return tagsResponse.Models, nil
}
// 拉取 Ollama 模型 (非流式)
func PullOllamaModel(baseURL, apiKey, modelName string) error {
url := fmt.Sprintf("%s/api/pull", baseURL)
pullRequest := OllamaPullRequest{
Name: modelName,
Stream: false, // 非流式,简化处理
}
requestBody, err := common.Marshal(pullRequest)
if err != nil {
return fmt.Errorf("序列化请求失败: %v", err)
}
client := &http.Client{
Timeout: 30 * 60 * 1000 * time.Millisecond, // 30分钟超时支持大模型
}
request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
if err != nil {
return fmt.Errorf("创建请求失败: %v", err)
}
request.Header.Set("Content-Type", "application/json")
if apiKey != "" {
request.Header.Set("Authorization", "Bearer "+apiKey)
}
response, err := client.Do(request)
if err != nil {
return fmt.Errorf("请求失败: %v", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
}
return nil
}
// 流式拉取 Ollama 模型 (支持进度回调)
func PullOllamaModelStream(baseURL, apiKey, modelName string, progressCallback func(OllamaPullResponse)) error {
url := fmt.Sprintf("%s/api/pull", baseURL)
pullRequest := OllamaPullRequest{
Name: modelName,
Stream: true, // 启用流式
}
requestBody, err := common.Marshal(pullRequest)
if err != nil {
return fmt.Errorf("序列化请求失败: %v", err)
}
client := &http.Client{
Timeout: 60 * 60 * 1000 * time.Millisecond, // 1小时超时支持超大模型
}
request, err := http.NewRequest("POST", url, strings.NewReader(string(requestBody)))
if err != nil {
return fmt.Errorf("创建请求失败: %v", err)
}
request.Header.Set("Content-Type", "application/json")
if apiKey != "" {
request.Header.Set("Authorization", "Bearer "+apiKey)
}
response, err := client.Do(request)
if err != nil {
return fmt.Errorf("请求失败: %v", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("拉取模型失败 %d: %s", response.StatusCode, string(body))
}
// 读取流式响应
scanner := bufio.NewScanner(response.Body)
successful := false
for scanner.Scan() {
line := scanner.Text()
if strings.TrimSpace(line) == "" {
continue
}
var pullResponse OllamaPullResponse
if err := common.Unmarshal([]byte(line), &pullResponse); err != nil {
continue // 忽略解析失败的行
}
if progressCallback != nil {
progressCallback(pullResponse)
}
// 检查是否出现错误或完成
if strings.EqualFold(pullResponse.Status, "error") {
return fmt.Errorf("拉取模型失败: %s", strings.TrimSpace(line))
}
if strings.EqualFold(pullResponse.Status, "success") {
successful = true
break
}
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("读取流式响应失败: %v", err)
}
if !successful {
return fmt.Errorf("拉取模型未完成: 未收到成功状态")
}
return nil
}
// 删除 Ollama 模型
func DeleteOllamaModel(baseURL, apiKey, modelName string) error {
url := fmt.Sprintf("%s/api/delete", baseURL)
deleteRequest := OllamaDeleteRequest{
Name: modelName,
}
requestBody, err := common.Marshal(deleteRequest)
if err != nil {
return fmt.Errorf("序列化请求失败: %v", err)
}
client := &http.Client{}
request, err := http.NewRequest("DELETE", url, strings.NewReader(string(requestBody)))
if err != nil {
return fmt.Errorf("创建请求失败: %v", err)
}
request.Header.Set("Content-Type", "application/json")
if apiKey != "" {
request.Header.Set("Authorization", "Bearer "+apiKey)
}
response, err := client.Do(request)
if err != nil {
return fmt.Errorf("请求失败: %v", err)
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return fmt.Errorf("删除模型失败 %d: %s", response.StatusCode, string(body))
}
return nil
}
func FetchOllamaVersion(baseURL, apiKey string) (string, error) {
trimmedBase := strings.TrimRight(baseURL, "/")
if trimmedBase == "" {
return "", fmt.Errorf("baseURL 为空")
}
url := fmt.Sprintf("%s/api/version", trimmedBase)
client := &http.Client{Timeout: 10 * time.Second}
request, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", fmt.Errorf("创建请求失败: %v", err)
}
if apiKey != "" {
request.Header.Set("Authorization", "Bearer "+apiKey)
}
response, err := client.Do(request)
if err != nil {
return "", fmt.Errorf("请求失败: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("读取响应失败: %v", err)
}
if response.StatusCode != http.StatusOK {
return "", fmt.Errorf("查询版本失败 %d: %s", response.StatusCode, string(body))
}
var versionResp struct {
Version string `json:"version"`
}
if err := json.Unmarshal(body, &versionResp); err != nil {
return "", fmt.Errorf("解析响应失败: %v", err)
}
if versionResp.Version == "" {
return "", fmt.Errorf("未返回版本信息")
}
return versionResp.Version, nil
}

View File

@@ -208,7 +208,6 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
helper.Done(c)
case types.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysLog("error unmarshalling stream response: " + err.Error())
@@ -221,6 +220,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
for _, resp := range claudeResponses {
_ = helper.ClaudeData(c, *resp)
}
info.ClaudeConvertInfo.Done = true
case types.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse

View File

@@ -186,7 +186,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
usage.CompletionTokens += toolCount * 7
}
applyUsagePostProcessing(info, usage, nil)
applyUsagePostProcessing(info, usage, common.StringToByteSlice(lastStreamData))
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
@@ -597,6 +597,7 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
case constant.ChannelTypeZhipu_v4:
// 智普的cached_tokens在标准位置: usage.prompt_tokens_details.cached_tokens
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
@@ -606,6 +607,19 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
case constant.ChannelTypeMoonshot:
// Moonshot的cached_tokens在非标准位置: choices[].usage.cached_tokens
if usage.PromptTokensDetails.CachedTokens == 0 {
if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
} else if cachedTokens, ok := extractMoonshotCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
usage.PromptTokensDetails.CachedTokens = cachedTokens
} else if usage.PromptCacheHitTokens > 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
}
}
@@ -639,3 +653,32 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
}
return 0, false
}
// extractMoonshotCachedTokensFromBody 从Moonshot的非标准位置提取cached_tokens
// Moonshot的流式响应格式: {"choices":[{"usage":{"cached_tokens":111}}]}
func extractMoonshotCachedTokensFromBody(body []byte) (int, bool) {
if len(body) == 0 {
return 0, false
}
var payload struct {
Choices []struct {
Usage struct {
CachedTokens *int `json:"cached_tokens"`
} `json:"usage"`
} `json:"choices"`
}
if err := common.Unmarshal(body, &payload); err != nil {
return 0, false
}
// 遍历choices查找cached_tokens
for _, choice := range payload.Choices {
if choice.Usage.CachedTokens != nil && *choice.Usage.CachedTokens > 0 {
return *choice.Usage.CachedTokens, true
}
}
return 0, false
}

View File

@@ -192,6 +192,10 @@ func sizeToResolution(size string) (string, error) {
func ProcessAliOtherRatios(aliReq *AliVideoRequest) (map[string]float64, error) {
otherRatios := make(map[string]float64)
aliRatios := map[string]map[string]float64{
"wan2.6-i2v": {
"720P": 1,
"1080P": 1 / 0.6,
},
"wan2.5-t2v-preview": {
"480P": 1,
"720P": 2,
@@ -287,7 +291,9 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
aliReq.Parameters.Size = "1280*720"
}
} else {
if strings.HasPrefix(req.Model, "wan2.5") {
if strings.HasPrefix(req.Model, "wan2.6") {
aliReq.Parameters.Resolution = "1080P"
} else if strings.HasPrefix(req.Model, "wan2.5") {
aliReq.Parameters.Resolution = "1080P"
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
aliReq.Parameters.Resolution = "720P"

View File

@@ -346,7 +346,7 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
}
taskInfo.Code = resPayload.Code
taskInfo.TaskID = resPayload.Data.TaskId
taskInfo.Reason = resPayload.Message
taskInfo.Reason = resPayload.Data.TaskStatusMsg
//任务状态枚举值submitted已提交、processing处理中、succeed成功、failed失败
status := resPayload.Data.TaskStatus
switch status {

View File

@@ -40,6 +40,7 @@ var claudeModelMap = map[string]string{
"claude-opus-4-20250514": "claude-opus-4@20250514",
"claude-opus-4-1-20250805": "claude-opus-4-1@20250805",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
"claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001",
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
}

View File

@@ -270,6 +270,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// return fmt.Sprintf("%s/api/v3/images/edits", baseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", baseUrl), nil
case constant.RelayModeResponses:
return fmt.Sprintf("%s/api/v3/responses", baseUrl), nil
case constant.RelayModeAudioSpeech:
if baseUrl == channelconstant.ChannelBaseURLs[channelconstant.ChannelTypeVolcEngine] {
return "wss://openspeech.bytedance.com/api/v1/tts/ws_binary", nil
@@ -323,7 +325,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {

View File

@@ -23,7 +23,7 @@ type ConditionOperation struct {
type ParamOperation struct {
Path string `json:"path"`
Mode string `json:"mode"` // delete, set, move, prepend, append
Mode string `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
Value interface{} `json:"value"`
KeepOrigin bool `json:"keep_origin"`
From string `json:"from,omitempty"`
@@ -330,8 +330,6 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
}
// 处理路径中的负数索引
opPath := processNegativeIndex(result, op.Path)
opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To)
switch op.Mode {
case "delete":
@@ -342,11 +340,38 @@ func applyOperations(jsonStr string, operations []ParamOperation, conditionConte
}
result, err = sjson.Set(result, opPath, op.Value)
case "move":
opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To)
result, err = moveValue(result, opFrom, opTo)
case "copy":
if op.From == "" || op.To == "" {
return "", fmt.Errorf("copy from/to is required")
}
opFrom := processNegativeIndex(result, op.From)
opTo := processNegativeIndex(result, op.To)
result, err = copyValue(result, opFrom, opTo)
case "prepend":
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
case "append":
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
case "trim_prefix":
result, err = trimStringValue(result, opPath, op.Value, true)
case "trim_suffix":
result, err = trimStringValue(result, opPath, op.Value, false)
case "ensure_prefix":
result, err = ensureStringAffix(result, opPath, op.Value, true)
case "ensure_suffix":
result, err = ensureStringAffix(result, opPath, op.Value, false)
case "trim_space":
result, err = transformStringValue(result, opPath, strings.TrimSpace)
case "to_lower":
result, err = transformStringValue(result, opPath, strings.ToLower)
case "to_upper":
result, err = transformStringValue(result, opPath, strings.ToUpper)
case "replace":
result, err = replaceStringValue(result, opPath, op.From, op.To)
case "regex_replace":
result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
default:
return "", fmt.Errorf("unknown operation: %s", op.Mode)
}
@@ -369,6 +394,14 @@ func moveValue(jsonStr, fromPath, toPath string) (string, error) {
return sjson.Delete(result, fromPath)
}
func copyValue(jsonStr, fromPath, toPath string) (string, error) {
sourceValue := gjson.Get(jsonStr, fromPath)
if !sourceValue.Exists() {
return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
}
return sjson.Set(jsonStr, toPath, sourceValue.Value())
}
func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
current := gjson.Get(jsonStr, path)
switch {
@@ -422,6 +455,88 @@ func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (stri
return sjson.Set(jsonStr, path, newStr)
}
func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if value == nil {
return jsonStr, fmt.Errorf("trim value is required")
}
valueStr := fmt.Sprintf("%v", value)
var newStr string
if isPrefix {
newStr = strings.TrimPrefix(current.String(), valueStr)
} else {
newStr = strings.TrimSuffix(current.String(), valueStr)
}
return sjson.Set(jsonStr, path, newStr)
}
func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if value == nil {
return jsonStr, fmt.Errorf("ensure value is required")
}
valueStr := fmt.Sprintf("%v", value)
if valueStr == "" {
return jsonStr, fmt.Errorf("ensure value is required")
}
currentStr := current.String()
if isPrefix {
if strings.HasPrefix(currentStr, valueStr) {
return jsonStr, nil
}
return sjson.Set(jsonStr, path, valueStr+currentStr)
}
if strings.HasSuffix(currentStr, valueStr) {
return jsonStr, nil
}
return sjson.Set(jsonStr, path, currentStr+valueStr)
}
func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
return sjson.Set(jsonStr, path, transform(current.String()))
}
func replaceStringValue(jsonStr, path, from, to string) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if from == "" {
return jsonStr, fmt.Errorf("replace from is required")
}
return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to))
}
func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
current := gjson.Get(jsonStr, path)
if current.Type != gjson.String {
return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
}
if pattern == "" {
return jsonStr, fmt.Errorf("regex pattern is required")
}
re, err := regexp.Compile(pattern)
if err != nil {
return jsonStr, err
}
return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
}
func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
current := gjson.Get(jsonStr, path)
var currentMap, newMap map[string]interface{}

View File

@@ -0,0 +1,791 @@
package common
import (
"encoding/json"
"reflect"
"testing"
)
func TestApplyParamOverrideTrimPrefix(t *testing.T) {
// trim_prefix example:
// {"operations":[{"path":"model","mode":"trim_prefix","value":"openai/"}]}
input := []byte(`{"model":"openai/gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "trim_prefix",
"value": "openai/",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideTrimSuffix(t *testing.T) {
// trim_suffix example:
// {"operations":[{"path":"model","mode":"trim_suffix","value":"-latest"}]}
input := []byte(`{"model":"gpt-4-latest","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "trim_suffix",
"value": "-latest",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideTrimNoop(t *testing.T) {
// trim_prefix no-op example:
// {"operations":[{"path":"model","mode":"trim_prefix","value":"openai/"}]}
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "trim_prefix",
"value": "openai/",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideTrimRequiresValue(t *testing.T) {
// trim_prefix requires value example:
// {"operations":[{"path":"model","mode":"trim_prefix"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "trim_prefix",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideReplace(t *testing.T) {
// replace example:
// {"operations":[{"path":"model","mode":"replace","from":"openai/","to":""}]}
input := []byte(`{"model":"openai/gpt-4o-mini","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "replace",
"from": "openai/",
"to": "",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4o-mini","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideRegexReplace(t *testing.T) {
// regex_replace example:
// {"operations":[{"path":"model","mode":"regex_replace","from":"^gpt-","to":"openai/gpt-"}]}
input := []byte(`{"model":"gpt-4o-mini","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "regex_replace",
"from": "^gpt-",
"to": "openai/gpt-",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"openai/gpt-4o-mini","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideReplaceRequiresFrom(t *testing.T) {
// replace requires from example:
// {"operations":[{"path":"model","mode":"replace"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "replace",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideRegexReplaceRequiresPattern(t *testing.T) {
// regex_replace requires from(pattern) example:
// {"operations":[{"path":"model","mode":"regex_replace"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "regex_replace",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideDelete(t *testing.T) {
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "delete",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
var got map[string]interface{}
if err := json.Unmarshal(out, &got); err != nil {
t.Fatalf("failed to unmarshal output JSON: %v", err)
}
if _, exists := got["temperature"]; exists {
t.Fatalf("expected temperature to be deleted")
}
}
func TestApplyParamOverrideSet(t *testing.T) {
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.1}`, string(out))
}
func TestApplyParamOverrideSetKeepOrigin(t *testing.T) {
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"keep_origin": true,
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideMove(t *testing.T) {
input := []byte(`{"model":"gpt-4","meta":{"x":1}}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "move",
"from": "model",
"to": "meta.model",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"meta":{"x":1,"model":"gpt-4"}}`, string(out))
}
func TestApplyParamOverrideMoveMissingSource(t *testing.T) {
input := []byte(`{"meta":{"x":1}}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "move",
"from": "model",
"to": "meta.model",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverridePrependAppendString(t *testing.T) {
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "prepend",
"value": "openai/",
},
map[string]interface{}{
"path": "model",
"mode": "append",
"value": "-latest",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"openai/gpt-4-latest"}`, string(out))
}
func TestApplyParamOverridePrependAppendArray(t *testing.T) {
input := []byte(`{"arr":[1,2]}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "arr",
"mode": "prepend",
"value": 0,
},
map[string]interface{}{
"path": "arr",
"mode": "append",
"value": []interface{}{3, 4},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"arr":[0,1,2,3,4]}`, string(out))
}
func TestApplyParamOverrideAppendObjectMergeKeepOrigin(t *testing.T) {
input := []byte(`{"obj":{"a":1}}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "obj",
"mode": "append",
"keep_origin": true,
"value": map[string]interface{}{
"a": 2,
"b": 3,
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"obj":{"a":1,"b":3}}`, string(out))
}
func TestApplyParamOverrideAppendObjectMergeOverride(t *testing.T) {
input := []byte(`{"obj":{"a":1}}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "obj",
"mode": "append",
"value": map[string]interface{}{
"a": 2,
"b": 3,
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"obj":{"a":2,"b":3}}`, string(out))
}
func TestApplyParamOverrideConditionORDefault(t *testing.T) {
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "prefix",
"value": "gpt",
},
map[string]interface{}{
"path": "model",
"mode": "prefix",
"value": "claude",
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.1}`, string(out))
}
func TestApplyParamOverrideConditionAND(t *testing.T) {
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"logic": "AND",
"conditions": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "prefix",
"value": "gpt",
},
map[string]interface{}{
"path": "temperature",
"mode": "gt",
"value": 0.5,
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.1}`, string(out))
}
func TestApplyParamOverrideConditionInvert(t *testing.T) {
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "prefix",
"value": "gpt",
"invert": true,
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideConditionPassMissingKey(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "prefix",
"value": "gpt",
"pass_missing_key": true,
},
},
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideConditionFromContext(t *testing.T) {
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "temperature",
"mode": "set",
"value": 0.1,
"conditions": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "prefix",
"value": "gpt",
},
},
},
},
}
ctx := map[string]interface{}{
"model": "gpt-4",
}
out, err := ApplyParamOverride(input, override, ctx)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"temperature":0.1}`, string(out))
}
func TestApplyParamOverrideNegativeIndexPath(t *testing.T) {
input := []byte(`{"arr":[{"model":"a"},{"model":"b"}]}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "arr.-1.model",
"mode": "set",
"value": "c",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"arr":[{"model":"a"},{"model":"c"}]}`, string(out))
}
func TestApplyParamOverrideRegexReplaceInvalidPattern(t *testing.T) {
// regex_replace invalid pattern example:
// {"operations":[{"path":"model","mode":"regex_replace","from":"(","to":"x"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "regex_replace",
"from": "(",
"to": "x",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideCopy(t *testing.T) {
// copy example:
// {"operations":[{"mode":"copy","from":"model","to":"original_model"}]}
input := []byte(`{"model":"gpt-4","temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "copy",
"from": "model",
"to": "original_model",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4","original_model":"gpt-4","temperature":0.7}`, string(out))
}
func TestApplyParamOverrideCopyMissingSource(t *testing.T) {
// copy missing source example:
// {"operations":[{"mode":"copy","from":"model","to":"original_model"}]}
input := []byte(`{"temperature":0.7}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "copy",
"from": "model",
"to": "original_model",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideCopyRequiresFromTo(t *testing.T) {
// copy requires from/to example:
// {"operations":[{"mode":"copy"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"mode": "copy",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideEnsurePrefix(t *testing.T) {
// ensure_prefix example:
// {"operations":[{"path":"model","mode":"ensure_prefix","value":"openai/"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "ensure_prefix",
"value": "openai/",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"openai/gpt-4"}`, string(out))
}
func TestApplyParamOverrideEnsurePrefixNoop(t *testing.T) {
// ensure_prefix no-op example:
// {"operations":[{"path":"model","mode":"ensure_prefix","value":"openai/"}]}
input := []byte(`{"model":"openai/gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "ensure_prefix",
"value": "openai/",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"openai/gpt-4"}`, string(out))
}
func TestApplyParamOverrideEnsureSuffix(t *testing.T) {
// ensure_suffix example:
// {"operations":[{"path":"model","mode":"ensure_suffix","value":"-latest"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "ensure_suffix",
"value": "-latest",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4-latest"}`, string(out))
}
func TestApplyParamOverrideEnsureSuffixNoop(t *testing.T) {
// ensure_suffix no-op example:
// {"operations":[{"path":"model","mode":"ensure_suffix","value":"-latest"}]}
input := []byte(`{"model":"gpt-4-latest"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "ensure_suffix",
"value": "-latest",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4-latest"}`, string(out))
}
func TestApplyParamOverrideEnsureRequiresValue(t *testing.T) {
// ensure_prefix requires value example:
// {"operations":[{"path":"model","mode":"ensure_prefix"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "ensure_prefix",
},
},
}
_, err := ApplyParamOverride(input, override, nil)
if err == nil {
t.Fatalf("expected error, got nil")
}
}
func TestApplyParamOverrideTrimSpace(t *testing.T) {
// trim_space example:
// {"operations":[{"path":"model","mode":"trim_space"}]}
input := []byte("{\"model\":\" gpt-4 \\n\"}")
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "trim_space",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4"}`, string(out))
}
func TestApplyParamOverrideToLower(t *testing.T) {
// to_lower example:
// {"operations":[{"path":"model","mode":"to_lower"}]}
input := []byte(`{"model":"GPT-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "to_lower",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"gpt-4"}`, string(out))
}
func TestApplyParamOverrideToUpper(t *testing.T) {
// to_upper example:
// {"operations":[{"path":"model","mode":"to_upper"}]}
input := []byte(`{"model":"gpt-4"}`)
override := map[string]interface{}{
"operations": []interface{}{
map[string]interface{}{
"path": "model",
"mode": "to_upper",
},
},
}
out, err := ApplyParamOverride(input, override, nil)
if err != nil {
t.Fatalf("ApplyParamOverride returned error: %v", err)
}
assertJSONEqual(t, `{"model":"GPT-4"}`, string(out))
}
func assertJSONEqual(t *testing.T, want, got string) {
t.Helper()
var wantObj interface{}
var gotObj interface{}
if err := json.Unmarshal([]byte(want), &wantObj); err != nil {
t.Fatalf("failed to unmarshal want JSON: %v", err)
}
if err := json.Unmarshal([]byte(got), &gotObj); err != nil {
t.Fatalf("failed to unmarshal got JSON: %v", err)
}
if !reflect.DeepEqual(wantObj, gotObj) {
t.Fatalf("json not equal\nwant: %s\ngot: %s", want, got)
}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/ratio_setting"
"github.com/QuantumNous/new-api/types"
"github.com/shopspring/decimal"
@@ -181,22 +182,25 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
return newApiErr
}
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
var containAudioTokens = usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0
var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName)
if containAudioTokens && containsAudioRatios {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, info, usage.(*dto.Usage), "")
postConsumeQuota(c, info, usage.(*dto.Usage))
}
return nil
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.GetEstimatePromptTokens(),
CompletionTokens: 0,
TotalTokens: relayInfo.GetEstimatePromptTokens(),
}
extraContent += "(可能是请求出错)"
extraContent = append(extraContent, "上游无计费信息")
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
@@ -246,8 +250,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s调用花费 %s",
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s调用花费 %s",
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
}
} else if strings.HasSuffix(modelName, "search-preview") {
// search-preview 模型不支持 response api
@@ -258,8 +262,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s调用花费 %s",
searchContextSize, dWebSearchQuota.String())
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s调用花费 %s",
searchContextSize, dWebSearchQuota.String()))
}
// claude web search tool 计费
var dClaudeWebSearchQuota decimal.Decimal
@@ -269,8 +273,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
}
// file search tool 计费
var dFileSearchQuota decimal.Decimal
@@ -281,8 +285,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
fileSearchTool.CallCount, dFileSearchQuota.String())
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
fileSearchTool.CallCount, dFileSearchQuota.String()))
}
}
var dImageGenerationCallQuota decimal.Decimal
@@ -290,7 +294,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
if ctx.GetBool("image_generation_call") {
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
}
var quotaCalculateDecimal decimal.Decimal
@@ -300,14 +304,20 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
if !relayInfo.PriceData.UsePrice {
baseTokens := dPromptTokens
// 减去 cached tokens
// Anthropic API 的 input_tokens 已经不包含缓存 tokens不需要减去
// OpenAI/OpenRouter 等 API 的 prompt_tokens 包含缓存 tokens需要减去
var cachedTokensWithRatio decimal.Decimal
if !dCacheTokens.IsZero() {
baseTokens = baseTokens.Sub(dCacheTokens)
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
baseTokens = baseTokens.Sub(dCacheTokens)
}
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
}
var dCachedCreationTokensWithRatio decimal.Decimal
if !dCachedCreationTokens.IsZero() {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
if relayInfo.ChannelType != constant.ChannelTypeAnthropic {
baseTokens = baseTokens.Sub(dCachedCreationTokens)
}
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
}
@@ -325,7 +335,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
// 重新计算 base tokens
baseTokens = baseTokens.Sub(dAudioTokens)
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
}
}
promptQuota := baseTokens.Add(cachedTokensWithRatio).
@@ -350,17 +360,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
// 添加 image generation call 计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
if len(relayInfo.PriceData.OtherRatios) > 0 {
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
dOtherRatio := decimal.NewFromFloat(otherRatio)
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
}
}
quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens
var logContent string
//var logContent string
// record all the consume log even if quota is 0
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
@@ -399,15 +417,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
logModel := modelName
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf("模型 %s", modelName)
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
}
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
logModel = "gpt-4o-gizmo-*"
logContent += fmt.Sprintf("模型 %s", modelName)
}
if extraContent != "" {
logContent += ", " + extraContent
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
}
logContent := strings.Join(extraContent, ", ")
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 {
other["image"] = true

View File

@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, info, usage.(*dto.Usage), "")
postConsumeQuota(c, info, usage.(*dto.Usage))
return nil
}

View File

@@ -193,7 +193,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return openaiErr
}
postConsumeQuota(c, info, usage.(*dto.Usage), "")
postConsumeQuota(c, info, usage.(*dto.Usage))
return nil
}
@@ -292,6 +292,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
return openaiErr
}
postConsumeQuota(c, info, usage.(*dto.Usage), "")
postConsumeQuota(c, info, usage.(*dto.Usage))
return nil
}

View File

@@ -110,8 +110,6 @@ func GetAndValidateEmbeddingRequest(c *gin.Context, relayMode int) (*dto.Embeddi
return embeddingRequest, nil
}
// GetAndValidateResponsesRequest parses the HTTP request body into an OpenAIResponsesRequest and ensures the Model field is provided.
// It returns the parsed request, or an error if the body cannot be parsed or the Model is empty.
func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest, error) {
request := &dto.OpenAIResponsesRequest{}
err := common.UnmarshalBodyReusable(c, request)
@@ -121,6 +119,9 @@ func GetAndValidateResponsesRequest(c *gin.Context) (*dto.OpenAIResponsesRequest
if request.Model == "" {
return nil, errors.New("model is required")
}
if request.Input == nil {
return nil, errors.New("input is required")
}
return request, nil
}
@@ -323,4 +324,4 @@ func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatch
return nil, err
}
return request, nil
}
}

View File

@@ -124,12 +124,18 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
quality = "hd"
}
var logContent string
var logContent []string
if len(request.Size) > 0 {
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
}
if len(quality) > 0 {
logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
}
if request.N > 0 {
logContent = append(logContent, fmt.Sprintf("生成数量 %d", request.N))
}
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
return nil
}

View File

@@ -95,6 +95,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, info, usage.(*dto.Usage), "")
postConsumeQuota(c, info, usage.(*dto.Usage))
return nil
}

View File

@@ -107,7 +107,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else {
postConsumeQuota(c, info, usage.(*dto.Usage), "")
postConsumeQuota(c, info, usage.(*dto.Usage))
}
return nil
}

View File

@@ -93,6 +93,10 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/2fa/enable", controller.Enable2FA)
selfRoute.POST("/2fa/disable", controller.Disable2FA)
selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes)
// Check-in routes
selfRoute.GET("/checkin", controller.GetCheckinStatus)
selfRoute.POST("/checkin", middleware.TurnstileCheck(), controller.DoCheckin)
}
adminRoute := userRoute.Group("/")
@@ -152,6 +156,10 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)
channelRoute.POST("/ollama/pull", controller.OllamaPullModel)
channelRoute.POST("/ollama/pull/stream", controller.OllamaPullModelStream)
channelRoute.DELETE("/ollama/delete", controller.OllamaDeleteModel)
channelRoute.GET("/ollama/version/:id", controller.OllamaVersion)
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
channelRoute.GET("/tag/models", controller.GetTagModels)
channelRoute.POST("/copy/:id", controller.CopyChannel)
@@ -256,5 +264,31 @@ func SetApiRouter(router *gin.Engine) {
modelsRoute.PUT("/", controller.UpdateModelMeta)
modelsRoute.DELETE("/:id", controller.DeleteModelMeta)
}
// Deployments (model deployment management)
deploymentsRoute := apiRouter.Group("/deployments")
deploymentsRoute.Use(middleware.AdminAuth())
{
deploymentsRoute.GET("/settings", controller.GetModelDeploymentSettings)
deploymentsRoute.POST("/settings/test-connection", controller.TestIoNetConnection)
deploymentsRoute.GET("/", controller.GetAllDeployments)
deploymentsRoute.GET("/search", controller.SearchDeployments)
deploymentsRoute.POST("/test-connection", controller.TestIoNetConnection)
deploymentsRoute.GET("/hardware-types", controller.GetHardwareTypes)
deploymentsRoute.GET("/locations", controller.GetLocations)
deploymentsRoute.GET("/available-replicas", controller.GetAvailableReplicas)
deploymentsRoute.POST("/price-estimation", controller.GetPriceEstimation)
deploymentsRoute.GET("/check-name", controller.CheckClusterNameAvailability)
deploymentsRoute.POST("/", controller.CreateDeployment)
deploymentsRoute.GET("/:id", controller.GetDeployment)
deploymentsRoute.GET("/:id/logs", controller.GetDeploymentLogs)
deploymentsRoute.GET("/:id/containers", controller.ListDeploymentContainers)
deploymentsRoute.GET("/:id/containers/:container_id", controller.GetContainerDetails)
deploymentsRoute.PUT("/:id", controller.UpdateDeployment)
deploymentsRoute.PUT("/:id/name", controller.UpdateDeploymentName)
deploymentsRoute.POST("/:id/extend", controller.ExtendDeployment)
deploymentsRoute.DELETE("/:id", controller.DeleteDeployment)
}
}
}

View File

@@ -389,25 +389,29 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
}
idx := blockIndex
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: toolCall.ID,
Type: "tool_use",
Name: toolCall.Function.Name,
Input: map[string]interface{}{},
},
})
if toolCall.Function.Name != "" {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Id: toolCall.ID,
Type: "tool_use",
Name: toolCall.Function.Name,
Input: map[string]interface{}{},
},
})
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "input_json_delta",
PartialJson: &toolCall.Function.Arguments,
},
})
if len(toolCall.Function.Arguments) > 0 {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &idx,
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "input_json_delta",
PartialJson: &toolCall.Function.Arguments,
},
})
}
info.ClaudeConvertInfo.Index = blockIndex
}
@@ -670,20 +674,21 @@ func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycomm
var tools []dto.ToolCallRequest
for _, tool := range geminiRequest.GetTools() {
if tool.FunctionDeclarations != nil {
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
if ok {
for _, function := range functionDeclarations {
openAITool := dto.ToolCallRequest{
Type: "function",
Function: dto.FunctionRequest{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
}
tools = append(tools, openAITool)
functionDeclarations, err := common.Any2Type[[]dto.FunctionRequest](tool.FunctionDeclarations)
if err != nil {
common.SysError(fmt.Sprintf("failed to parse gemini function declarations: %v (type=%T)", err, tool.FunctionDeclarations))
continue
}
for _, function := range functionDeclarations {
openAITool := dto.ToolCallRequest{
Type: "function",
Function: dto.FunctionRequest{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
}
tools = append(tools, openAITool)
}
}
}

View File

@@ -81,33 +81,26 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
return claudeErr
}
// RelayErrorHandler converts an HTTP error response into a structured types.NewAPIError.
// It returns a NewAPIError initialized with the response status code and one of:
// - an Err describing an absent or unreadable body,
// - an Err containing the unmarshaled error message (or status + raw body when showBodyWhenFail is true), or
// - an embedded OpenAI-style error when the response body contains a compatible error object.
// The returned NewAPIError's status code reflects resp.StatusCode.
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
if resp.Body == nil {
newApiErr.Err = errors.New("response body is nil")
return
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
CloseResponseBodyGracefully(resp)
newApiErr.Err = fmt.Errorf("read response body failed: %w", err)
return
}
CloseResponseBodyGracefully(resp)
var errResponse dto.GeneralErrorResponse
buildErrWithBody := func(message string) error {
if message == "" {
return fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
}
return fmt.Errorf("bad response status code %d, message: %s, body: %s", resp.StatusCode, message, string(responseBody))
}
err = common.Unmarshal(responseBody, &errResponse)
if err != nil {
if showBodyWhenFail {
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
newApiErr.Err = buildErrWithBody("")
} else {
logger.LogError(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
@@ -120,10 +113,16 @@ func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFai
oaiError := errResponse.TryToOpenAIError()
if oaiError != nil {
newApiErr = types.WithOpenAIError(*oaiError, resp.StatusCode)
if showBodyWhenFail {
newApiErr.Err = buildErrWithBody(newApiErr.Error())
}
return
}
}
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
if showBodyWhenFail {
newApiErr.Err = buildErrWithBody(newApiErr.Error())
}
return
}
@@ -169,4 +168,4 @@ func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
}
return taskError
}
}

View File

@@ -57,4 +57,5 @@ func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
if err != nil {
logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
}
c.Writer.Flush()
}

View File

@@ -38,6 +38,7 @@ func InitHttpClient() {
MaxIdleConns: common.RelayMaxIdleConns,
MaxIdleConnsPerHost: common.RelayMaxIdleConnsPerHost,
ForceAttemptHTTP2: true,
Proxy: http.ProxyFromEnvironment, // Support HTTP_PROXY, HTTPS_PROXY, NO_PROXY env vars
}
if common.RelayTimeout == 0 {
@@ -81,6 +82,9 @@ func ResetProxyClientCache() {
// NewProxyHttpClient 创建支持代理的 HTTP 客户端
func NewProxyHttpClient(proxyURL string) (*http.Client, error) {
if proxyURL == "" {
if client := GetHttpClient(); client != nil {
return client, nil
}
return http.DefaultClient, nil
}

View File

@@ -95,7 +95,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return err
}
token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
token, err := model.GetTokenByKey(strings.TrimPrefix(relayInfo.TokenKey, "sk-"), false)
if err != nil {
return err
}

View File

@@ -0,0 +1,37 @@
package operation_setting
import "github.com/QuantumNous/new-api/setting/config"
// CheckinSetting 签到功能配置
type CheckinSetting struct {
Enabled bool `json:"enabled"` // 是否启用签到功能
MinQuota int `json:"min_quota"` // 签到最小额度奖励
MaxQuota int `json:"max_quota"` // 签到最大额度奖励
}
// 默认配置
var checkinSetting = CheckinSetting{
Enabled: false, // 默认关闭
MinQuota: 1000, // 默认最小额度 1000 (约 0.002 USD)
MaxQuota: 10000, // 默认最大额度 10000 (约 0.02 USD)
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("checkin_setting", &checkinSetting)
}
// GetCheckinSetting 获取签到配置
func GetCheckinSetting() *CheckinSetting {
return &checkinSetting
}
// IsCheckinEnabled 是否启用签到功能
func IsCheckinEnabled() bool {
return checkinSetting.Enabled
}
// GetCheckinQuotaRange 获取签到额度范围
func GetCheckinQuotaRange() (min, max int) {
return checkinSetting.MinQuota, checkinSetting.MaxQuota
}

View File

@@ -7,7 +7,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
)
// from songquanpeng/one-api
@@ -312,6 +311,10 @@ var defaultAudioCompletionRatio = map[string]float64{
"gpt-4o-realtime": 2,
"gpt-4o-mini-realtime": 2,
"gpt-4o-mini-tts": 1,
"tts-1": 0,
"tts-1-hd": 0,
"tts-1-1106": 0,
"tts-1-hd-1106": 0,
}
var (
@@ -657,7 +660,7 @@ func GetAudioRatio(name string) float64 {
if ratio, ok := audioRatioMap[name]; ok {
return ratio
}
return 20
return 1
}
func GetAudioCompletionRatio(name string) float64 {
@@ -668,7 +671,23 @@ func GetAudioCompletionRatio(name string) float64 {
return ratio
}
return 2
return 1
}
func ContainsAudioRatio(name string) bool {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
_, ok := audioRatioMap[name]
return ok
}
func ContainsAudioCompletionRatio(name string) bool {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
name = FormatMatchingModelName(name)
_, ok := audioCompletionRatioMap[name]
return ok
}
func ModelRatio2JSONString() string {
@@ -746,16 +765,6 @@ func UpdateAudioRatioByJSONString(jsonStr string) error {
return nil
}
func GetAudioRatioCopy() map[string]float64 {
audioRatioMapMutex.RLock()
defer audioRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(audioRatioMap))
for k, v := range audioRatioMap {
copyMap[k] = v
}
return copyMap
}
func AudioCompletionRatio2JSONString() string {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
@@ -778,16 +787,6 @@ func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
return nil
}
func GetAudioCompletionRatioCopy() map[string]float64 {
audioCompletionRatioMapMutex.RLock()
defer audioCompletionRatioMapMutex.RUnlock()
copyMap := make(map[string]float64, len(audioCompletionRatioMap))
for k, v := range audioCompletionRatioMap {
copyMap[k] = v
}
return copyMap
}
func GetModelRatioCopy() map[string]float64 {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
@@ -829,10 +828,6 @@ func FormatMatchingModelName(name string) string {
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
}
if base, _, ok := reasoning.TrimEffortSuffix(name); ok {
name = base
}
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}

View File

@@ -6,7 +6,7 @@ import (
"github.com/samber/lo"
)
var EffortSuffixes = []string{"-high", "-medium", "-low"}
var EffortSuffixes = []string{"-high", "-medium", "-low", "-minimal"}
// TrimEffortSuffix -> modelName level(low) exists
func TrimEffortSuffix(modelName string) (string, string, bool) {

View File

@@ -1,6 +1,7 @@
package types
import (
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -10,10 +11,11 @@ import (
)
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
Metadata json.RawMessage `json:"metadata,omitempty"`
}
type ClaudeError struct {
@@ -92,6 +94,7 @@ type NewAPIError struct {
errorType ErrorType
errorCode ErrorCode
StatusCode int
Metadata json.RawMessage
}
// Unwrap enables errors.Is / errors.As to work with NewAPIError by exposing the underlying error.
@@ -301,6 +304,13 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIError
Err: errors.New(openAIError.Message),
errorCode: ErrorCode(code),
}
// OpenRouter
if len(openAIError.Metadata) > 0 {
openAIError.Message = fmt.Sprintf("%s (%s)", openAIError.Message, openAIError.Metadata)
e.Metadata = openAIError.Metadata
e.RelayError = openAIError
e.Err = errors.New(openAIError.Message)
}
for _, op := range ops {
op(e)
}

View File

@@ -26,12 +26,22 @@ type PriceData struct {
GroupRatioInfo GroupRatioInfo
}
func (p *PriceData) AddOtherRatio(key string, ratio float64) {
if p.OtherRatios == nil {
p.OtherRatios = make(map[string]float64)
}
if ratio <= 0 {
return
}
p.OtherRatios[key] = ratio
}
type PerCallPriceData struct {
ModelPrice float64
Quota int
GroupRatioInfo GroupRatioInfo
}
func (p PriceData) ToSetting() string {
func (p *PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
}

View File

@@ -25,7 +25,9 @@ export default defineConfig({
"zh",
"en",
"fr",
"ru"
"ru",
"ja",
"vi"
],
extract: {
input: [

View File

@@ -42,6 +42,7 @@ import Midjourney from './pages/Midjourney';
import Pricing from './pages/Pricing';
import Task from './pages/Task';
import ModelPage from './pages/Model';
import ModelDeploymentPage from './pages/ModelDeployment';
import Playground from './pages/Playground';
import OAuth2Callback from './components/auth/OAuth2Callback';
import PersonalSetting from './components/settings/PersonalSetting';
@@ -108,6 +109,14 @@ function App() {
</AdminRoute>
}
/>
<Route
path='/console/deployment'
element={
<AdminRoute>
<ModelDeploymentPage />
</AdminRoute>
}
/>
<Route
path='/console/channel'
element={

View File

@@ -59,6 +59,11 @@ import { SiDiscord }from 'react-icons/si';
const LoginForm = () => {
let navigate = useNavigate();
const { t } = useTranslation();
const githubButtonTextKeyByState = {
idle: '使用 GitHub 继续',
redirecting: '正在跳转 GitHub...',
timeout: '请求超时,请刷新页面后重新发起 GitHub 登录',
};
const [inputs, setInputs] = useState({
username: '',
password: '',
@@ -90,9 +95,10 @@ const LoginForm = () => {
const [agreedToTerms, setAgreedToTerms] = useState(false);
const [hasUserAgreement, setHasUserAgreement] = useState(false);
const [hasPrivacyPolicy, setHasPrivacyPolicy] = useState(false);
const [githubButtonText, setGithubButtonText] = useState('使用 GitHub 继续');
const [githubButtonState, setGithubButtonState] = useState('idle');
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
const githubTimeoutRef = useRef(null);
const githubButtonText = t(githubButtonTextKeyByState[githubButtonState]);
const logo = getLogo();
const systemName = getSystemName();
@@ -284,13 +290,13 @@ const LoginForm = () => {
}
setGithubLoading(true);
setGithubButtonDisabled(true);
setGithubButtonText(t('正在跳转 GitHub...'));
setGithubButtonState('redirecting');
if (githubTimeoutRef.current) {
clearTimeout(githubTimeoutRef.current);
}
githubTimeoutRef.current = setTimeout(() => {
setGithubLoading(false);
setGithubButtonText(t('请求超时,请刷新页面后重新发起 GitHub 登录'));
setGithubButtonState('timeout');
setGithubButtonDisabled(true);
}, 20000);
try {

View File

@@ -57,6 +57,11 @@ import { SiDiscord } from 'react-icons/si';
const RegisterForm = () => {
let navigate = useNavigate();
const { t } = useTranslation();
const githubButtonTextKeyByState = {
idle: '使用 GitHub 继续',
redirecting: '正在跳转 GitHub...',
timeout: '请求超时,请刷新页面后重新发起 GitHub 登录',
};
const [inputs, setInputs] = useState({
username: '',
password: '',
@@ -88,9 +93,10 @@ const RegisterForm = () => {
const [agreedToTerms, setAgreedToTerms] = useState(false);
const [hasUserAgreement, setHasUserAgreement] = useState(false);
const [hasPrivacyPolicy, setHasPrivacyPolicy] = useState(false);
const [githubButtonText, setGithubButtonText] = useState('使用 GitHub 继续');
const [githubButtonState, setGithubButtonState] = useState('idle');
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
const githubTimeoutRef = useRef(null);
const githubButtonText = t(githubButtonTextKeyByState[githubButtonState]);
const logo = getLogo();
const systemName = getSystemName();
@@ -251,13 +257,13 @@ const RegisterForm = () => {
}
setGithubLoading(true);
setGithubButtonDisabled(true);
setGithubButtonText(t('正在跳转 GitHub...'));
setGithubButtonState('redirecting');
if (githubTimeoutRef.current) {
clearTimeout(githubTimeoutRef.current);
}
githubTimeoutRef.current = setTimeout(() => {
setGithubLoading(false);
setGithubButtonText(t('请求超时,请刷新页面后重新发起 GitHub 登录'));
setGithubButtonState('timeout');
setGithubButtonDisabled(true);
}, 20000);
try {

View File

@@ -45,6 +45,7 @@ const routerMap = {
pricing: '/pricing',
task: '/console/task',
models: '/console/models',
deployment: '/console/deployment',
playground: '/console/playground',
personal: '/console/personal',
};
@@ -157,6 +158,12 @@ const SiderBar = ({ onNavigate = () => {} }) => {
to: '/console/models',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('模型部署'),
itemKey: 'deployment',
to: '/deployment',
className: isAdmin() ? '' : 'tableHiddle',
},
{
text: t('兑换码管理'),
itemKey: 'redemption',

View File

@@ -52,7 +52,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Title
active
style={{ width: isMobile ? 40 : width, height }}
/>
}
@@ -71,7 +70,7 @@ const SkeletonWrapper = ({
loading={true}
active
placeholder={
<Skeleton.Avatar active size='extra-small' className='shadow-sm' />
<Skeleton.Avatar size='extra-small' className='shadow-sm' />
}
/>
<div className='ml-1.5 mr-1'>
@@ -80,7 +79,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Title
active
style={{ width: isMobile ? 15 : width, height: 12 }}
/>
}
@@ -98,7 +96,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Image
active
className={`absolute inset-0 !rounded-full ${className}`}
style={{ width: '100%', height: '100%' }}
/>
@@ -113,7 +110,7 @@ const SkeletonWrapper = ({
<Skeleton
loading={true}
active
placeholder={<Skeleton.Title active style={{ width, height: 24 }} />}
placeholder={<Skeleton.Title style={{ width, height: 24 }} />}
/>
);
};
@@ -125,7 +122,7 @@ const SkeletonWrapper = ({
<Skeleton
loading={true}
active
placeholder={<Skeleton.Title active style={{ width, height }} />}
placeholder={<Skeleton.Title style={{ width, height }} />}
/>
</div>
);
@@ -140,7 +137,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Title
active
style={{ width, height, borderRadius: 9999 }}
/>
}
@@ -164,7 +160,7 @@ const SkeletonWrapper = ({
loading={true}
active
placeholder={
<Skeleton.Avatar active size='extra-small' shape='square' />
<Skeleton.Avatar size='extra-small' shape='square' />
}
/>
</div>
@@ -174,7 +170,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Title
active
style={{ width: width || 80, height: height || 14 }}
/>
}
@@ -191,10 +186,7 @@ const SkeletonWrapper = ({
loading={true}
active
placeholder={
<Skeleton.Title
active
style={{ width: width || 60, height: height || 12 }}
/>
<Skeleton.Title style={{ width: width || 60, height: height || 12 }} />
}
/>
</div>
@@ -217,7 +209,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Avatar
active
shape='square'
style={{ width: ICON_SIZE, height: ICON_SIZE }}
/>
@@ -231,7 +222,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Title
active
style={{ width: labelWidth, height: TEXT_HEIGHT }}
/>
}
@@ -269,7 +259,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Avatar
active
shape='square'
style={{ width: ICON_SIZE, height: ICON_SIZE }}
/>
@@ -329,7 +318,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Title
active
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
/>
}
@@ -350,7 +338,6 @@ const SkeletonWrapper = ({
active
placeholder={
<Skeleton.Title
active
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
/>
}

View File

@@ -0,0 +1,412 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Card, Button, Typography } from '@douyinfe/semi-ui';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom';
import { Settings, Server, AlertCircle, WifiOff } from 'lucide-react';
const { Title, Text } = Typography;
const DeploymentAccessGuard = ({
children,
loading,
isEnabled,
connectionLoading,
connectionOk,
connectionError,
onRetry,
}) => {
const { t } = useTranslation();
const navigate = useNavigate();
const handleGoToSettings = () => {
navigate('/console/setting?tab=model-deployment');
};
if (loading) {
return (
<div className='mt-[60px] px-2'>
<Card loading={true} style={{ minHeight: '400px' }}>
<div style={{ textAlign: 'center', padding: '50px 0' }}>
<Text type='secondary'>{t('加载设置中...')}</Text>
</div>
</Card>
</div>
);
}
if (!isEnabled) {
return (
<div
className='mt-[60px] px-4'
style={{
minHeight: 'calc(100vh - 60px)',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
}}
>
<div
style={{
maxWidth: '600px',
width: '100%',
textAlign: 'center',
padding: '0 20px',
}}
>
<Card
style={{
padding: '60px 40px',
borderRadius: '16px',
border: '1px solid var(--semi-color-border)',
boxShadow: '0 4px 20px rgba(0, 0, 0, 0.08)',
background:
'linear-gradient(135deg, var(--semi-color-bg-0) 0%, var(--semi-color-fill-0) 100%)',
}}
>
{/* 图标区域 */}
<div style={{ marginBottom: '32px' }}>
<div
style={{
display: 'inline-flex',
alignItems: 'center',
justifyContent: 'center',
width: '120px',
height: '120px',
borderRadius: '50%',
background:
'linear-gradient(135deg, rgba(var(--semi-orange-4), 0.15) 0%, rgba(var(--semi-orange-5), 0.1) 100%)',
border: '3px solid rgba(var(--semi-orange-4), 0.3)',
marginBottom: '24px',
}}
>
<AlertCircle size={56} color='var(--semi-color-warning)' />
</div>
</div>
{/* 标题区域 */}
<div style={{ marginBottom: '24px' }}>
<Title
heading={2}
style={{
color: 'var(--semi-color-text-0)',
margin: '0 0 12px 0',
fontSize: '28px',
fontWeight: '700',
}}
>
{t('模型部署服务未启用')}
</Title>
<Text
style={{
fontSize: '18px',
lineHeight: '1.6',
color: 'var(--semi-color-text-1)',
display: 'block',
}}
>
{t('访问模型部署功能需要先启用 io.net 部署服务')}
</Text>
</div>
{/* 配置要求区域 */}
<div
style={{
backgroundColor: 'var(--semi-color-bg-1)',
padding: '24px',
borderRadius: '12px',
border: '1px solid var(--semi-color-border)',
margin: '32px 0',
boxShadow: '0 2px 8px rgba(0, 0, 0, 0.04)',
}}
>
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
gap: '12px',
marginBottom: '16px',
}}
>
<div
style={{
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
width: '32px',
height: '32px',
borderRadius: '8px',
backgroundColor: 'rgba(var(--semi-blue-4), 0.15)',
}}
>
<Server size={20} color='var(--semi-color-primary)' />
</div>
<Text
strong
style={{
fontSize: '16px',
color: 'var(--semi-color-text-0)',
}}
>
{t('需要配置的项目')}
</Text>
</div>
<div
style={{
display: 'flex',
flexDirection: 'column',
gap: '12px',
alignItems: 'flex-start',
textAlign: 'left',
maxWidth: '320px',
margin: '0 auto',
}}
>
<div
style={{ display: 'flex', alignItems: 'center', gap: '12px' }}
>
<div
style={{
width: '6px',
height: '6px',
borderRadius: '50%',
backgroundColor: 'var(--semi-color-primary)',
flexShrink: 0,
}}
></div>
<Text
style={{
fontSize: '15px',
color: 'var(--semi-color-text-1)',
}}
>
{t('启用 io.net 部署开关')}
</Text>
</div>
<div
style={{ display: 'flex', alignItems: 'center', gap: '12px' }}
>
<div
style={{
width: '6px',
height: '6px',
borderRadius: '50%',
backgroundColor: 'var(--semi-color-primary)',
flexShrink: 0,
}}
></div>
<Text
style={{
fontSize: '15px',
color: 'var(--semi-color-text-1)',
}}
>
{t('配置有效的 io.net API Key')}
</Text>
</div>
</div>
</div>
{/* 操作链接区域 */}
<div style={{ marginBottom: '20px' }}>
<div
onClick={handleGoToSettings}
style={{
display: 'inline-flex',
alignItems: 'center',
gap: '8px',
cursor: 'pointer',
padding: '12px 24px',
borderRadius: '8px',
fontSize: '16px',
fontWeight: '500',
color: 'var(--semi-color-primary)',
background: 'var(--semi-color-fill-0)',
border: '1px solid var(--semi-color-border)',
transition: 'all 0.2s ease',
textDecoration: 'none',
}}
onMouseEnter={(e) => {
e.currentTarget.style.background = 'var(--semi-color-fill-1)';
e.currentTarget.style.transform = 'translateY(-1px)';
e.currentTarget.style.boxShadow =
'0 2px 8px rgba(0, 0, 0, 0.1)';
}}
onMouseLeave={(e) => {
e.currentTarget.style.background = 'var(--semi-color-fill-0)';
e.currentTarget.style.transform = 'translateY(0)';
e.currentTarget.style.boxShadow = 'none';
}}
>
<Settings size={18} />
{t('前往设置页面')}
</div>
</div>
{/* 底部提示 */}
<Text
type='tertiary'
style={{
fontSize: '14px',
color: 'var(--semi-color-text-2)',
lineHeight: '1.5',
}}
>
{t('配置完成后刷新页面即可使用模型部署功能')}
</Text>
</Card>
</div>
</div>
);
}
if (connectionLoading || (connectionOk === null && !connectionError)) {
return (
<div className='mt-[60px] px-2'>
<Card loading={true} style={{ minHeight: '400px' }}>
<div style={{ textAlign: 'center', padding: '50px 0' }}>
<Text type='secondary'>{t('正在检查 io.net 连接...')}</Text>
</div>
</Card>
</div>
);
}
if (connectionOk === false) {
const isExpired = connectionError?.type === 'expired';
const title = isExpired ? t('接口密钥已过期') : t('无法连接 io.net');
const description = isExpired
? t('当前 API 密钥已过期,请在设置中更新。')
: t('当前配置无法连接到 io.net。');
const detail = connectionError?.message || '';
return (
<div
className='mt-[60px] px-4'
style={{
minHeight: 'calc(100vh - 60px)',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
}}
>
<div
style={{
maxWidth: '600px',
width: '100%',
textAlign: 'center',
padding: '0 20px',
}}
>
<Card
style={{
padding: '60px 40px',
borderRadius: '16px',
border: '1px solid var(--semi-color-border)',
boxShadow: '0 4px 20px rgba(0, 0, 0, 0.08)',
background:
'linear-gradient(135deg, var(--semi-color-bg-0) 0%, var(--semi-color-fill-0) 100%)',
}}
>
<div style={{ marginBottom: '32px' }}>
<div
style={{
display: 'inline-flex',
alignItems: 'center',
justifyContent: 'center',
width: '120px',
height: '120px',
borderRadius: '50%',
background:
'linear-gradient(135deg, rgba(var(--semi-red-4), 0.15) 0%, rgba(var(--semi-red-5), 0.1) 100%)',
border: '3px solid rgba(var(--semi-red-4), 0.3)',
marginBottom: '24px',
}}
>
<WifiOff size={56} color='var(--semi-color-danger)' />
</div>
</div>
<div style={{ marginBottom: '24px' }}>
<Title
heading={2}
style={{
color: 'var(--semi-color-text-0)',
margin: '0 0 12px 0',
fontSize: '28px',
fontWeight: '700',
}}
>
{title}
</Title>
<Text
style={{
fontSize: '18px',
lineHeight: '1.6',
color: 'var(--semi-color-text-1)',
display: 'block',
}}
>
{description}
</Text>
{detail ? (
<Text
type='tertiary'
style={{
fontSize: '14px',
lineHeight: '1.5',
display: 'block',
marginTop: '8px',
}}
>
{detail}
</Text>
) : null}
</div>
<div
style={{ display: 'flex', gap: '12px', justifyContent: 'center' }}
>
<Button
type='primary'
icon={<Settings size={18} />}
onClick={handleGoToSettings}
>
{t('前往设置')}
</Button>
{onRetry ? (
<Button type='tertiary' onClick={onRetry}>
{t('重试连接')}
</Button>
) : null}
</div>
</Card>
</div>
</div>
);
}
return children;
};
export default DeploymentAccessGuard;

View File

@@ -0,0 +1,85 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useEffect, useState } from 'react';
import { Card, Spin } from '@douyinfe/semi-ui';
import { API, showError, toBoolean } from '../../helpers';
import { useTranslation } from 'react-i18next';
import SettingModelDeployment from '../../pages/Setting/Model/SettingModelDeployment';
const ModelDeploymentSetting = () => {
const { t } = useTranslation();
let [inputs, setInputs] = useState({
'model_deployment.ionet.api_key': '',
'model_deployment.ionet.enabled': false,
});
let [loading, setLoading] = useState(false);
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {
'model_deployment.ionet.api_key': '',
'model_deployment.ionet.enabled': false,
};
data.forEach((item) => {
if (item.key.endsWith('Enabled') || item.key.endsWith('enabled')) {
newInputs[item.key] = toBoolean(item.value);
} else {
newInputs[item.key] = item.value;
}
});
setInputs(newInputs);
} else {
showError(message);
}
};
async function onRefresh() {
try {
setLoading(true);
await getOptions();
} catch (error) {
showError('刷新失败');
console.error(error);
} finally {
setLoading(false);
}
}
useEffect(() => {
onRefresh();
}, []);
return (
<>
<Spin spinning={loading} size='large'>
<Card style={{ marginTop: '10px' }}>
<SettingModelDeployment options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);
};
export default ModelDeploymentSetting;

View File

@@ -26,6 +26,7 @@ import SettingsSensitiveWords from '../../pages/Setting/Operation/SettingsSensit
import SettingsLog from '../../pages/Setting/Operation/SettingsLog';
import SettingsMonitoring from '../../pages/Setting/Operation/SettingsMonitoring';
import SettingsCreditLimit from '../../pages/Setting/Operation/SettingsCreditLimit';
import SettingsCheckin from '../../pages/Setting/Operation/SettingsCheckin';
import { API, showError, toBoolean } from '../../helpers';
const OperationSetting = () => {
@@ -70,7 +71,10 @@ const OperationSetting = () => {
AutomaticEnableChannelEnabled: false,
AutomaticDisableKeywords: '',
'monitor_setting.auto_test_channel_enabled': false,
'monitor_setting.auto_test_channel_minutes': 10,
'monitor_setting.auto_test_channel_minutes': 10 /* 签到设置 */,
'checkin_setting.enabled': false,
'checkin_setting.min_quota': 1000,
'checkin_setting.max_quota': 10000,
});
let [loading, setLoading] = useState(false);
@@ -140,6 +144,10 @@ const OperationSetting = () => {
<Card style={{ marginTop: '10px' }}>
<SettingsCreditLimit options={inputs} refresh={onRefresh} />
</Card>
{/* 签到设置 */}
<Card style={{ marginTop: '10px' }}>
<SettingsCheckin options={inputs} refresh={onRefresh} />
</Card>
</Spin>
</>
);

View File

@@ -39,6 +39,7 @@ import { useTranslation } from 'react-i18next';
import UserInfoHeader from './personal/components/UserInfoHeader';
import AccountManagement from './personal/cards/AccountManagement';
import NotificationSettings from './personal/cards/NotificationSettings';
import CheckinCalendar from './personal/cards/CheckinCalendar';
import EmailBindModal from './personal/modals/EmailBindModal';
import WeChatBindModal from './personal/modals/WeChatBindModal';
import AccountDeleteModal from './personal/modals/AccountDeleteModal';
@@ -314,10 +315,10 @@ const PersonalSetting = () => {
};
const changePassword = async () => {
if (inputs.original_password === '') {
showError(t('请输入原密码!'));
return;
}
// if (inputs.original_password === '') {
// showError(t('请输入原密码!'));
// return;
// }
if (inputs.set_new_password === '') {
showError(t('请输入新密码!'));
return;
@@ -447,6 +448,18 @@ const PersonalSetting = () => {
{/* 顶部用户信息区域 */}
<UserInfoHeader t={t} userState={userState} />
{/* 签到日历 - 仅在启用时显示 */}
{status?.checkin_enabled && (
<div className='mt-4 md:mt-6'>
<CheckinCalendar
t={t}
status={status}
turnstileEnabled={turnstileEnabled}
turnstileSiteKey={turnstileSiteKey}
/>
</div>
)}
{/* 账户管理和其他设置 */}
<div className='grid grid-cols-1 xl:grid-cols-2 items-start gap-4 md:gap-6 mt-4 md:mt-6'>
{/* 左侧:账户管理设置 */}

View File

@@ -0,0 +1,384 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState, useEffect, useMemo } from 'react';
import {
Card,
Calendar,
Button,
Typography,
Avatar,
Spin,
Tooltip,
Collapsible,
Modal,
} from '@douyinfe/semi-ui';
import {
CalendarCheck,
Gift,
Check,
ChevronDown,
ChevronUp,
} from 'lucide-react';
import Turnstile from 'react-turnstile';
import { API, showError, showSuccess, renderQuota } from '../../../../helpers';
const CheckinCalendar = ({ t, status, turnstileEnabled, turnstileSiteKey }) => {
const [loading, setLoading] = useState(false);
const [checkinLoading, setCheckinLoading] = useState(false);
const [turnstileModalVisible, setTurnstileModalVisible] = useState(false);
const [turnstileWidgetKey, setTurnstileWidgetKey] = useState(0);
const [checkinData, setCheckinData] = useState({
enabled: false,
stats: {
checked_in_today: false,
total_checkins: 0,
total_quota: 0,
checkin_count: 0,
records: [],
},
});
const [currentMonth, setCurrentMonth] = useState(
new Date().toISOString().slice(0, 7),
);
// 初始加载状态,用于避免折叠状态闪烁
const [initialLoaded, setInitialLoaded] = useState(false);
// 折叠状态null 表示未确定(等待首次加载)
const [isCollapsed, setIsCollapsed] = useState(null);
// 创建日期到额度的映射,方便快速查找
const checkinRecordsMap = useMemo(() => {
const map = {};
const records = checkinData.stats?.records || [];
records.forEach((record) => {
map[record.checkin_date] = record.quota_awarded;
});
return map;
}, [checkinData.stats?.records]);
// 计算本月获得的额度
const monthlyQuota = useMemo(() => {
const records = checkinData.stats?.records || [];
return records.reduce(
(sum, record) => sum + (record.quota_awarded || 0),
0,
);
}, [checkinData.stats?.records]);
// 获取签到状态
const fetchCheckinStatus = async (month) => {
const isFirstLoad = !initialLoaded;
setLoading(true);
try {
const res = await API.get(`/api/user/checkin?month=${month}`);
const { success, data, message } = res.data;
if (success) {
setCheckinData(data);
// 首次加载时,根据签到状态设置折叠状态
if (isFirstLoad) {
setIsCollapsed(data.stats?.checked_in_today ?? false);
setInitialLoaded(true);
}
} else {
showError(message || t('获取签到状态失败'));
if (isFirstLoad) {
setIsCollapsed(false);
setInitialLoaded(true);
}
}
} catch (error) {
showError(t('获取签到状态失败'));
if (isFirstLoad) {
setIsCollapsed(false);
setInitialLoaded(true);
}
} finally {
setLoading(false);
}
};
const postCheckin = async (token) => {
const url = token
? `/api/user/checkin?turnstile=${encodeURIComponent(token)}`
: '/api/user/checkin';
return API.post(url);
};
const shouldTriggerTurnstile = (message) => {
if (!turnstileEnabled) return false;
if (typeof message !== 'string') return true;
return message.includes('Turnstile');
};
const doCheckin = async (token) => {
setCheckinLoading(true);
try {
const res = await postCheckin(token);
const { success, data, message } = res.data;
if (success) {
showSuccess(
t('签到成功!获得') + ' ' + renderQuota(data.quota_awarded),
);
// 刷新签到状态
fetchCheckinStatus(currentMonth);
setTurnstileModalVisible(false);
} else {
if (!token && shouldTriggerTurnstile(message)) {
if (!turnstileSiteKey) {
showError('Turnstile is enabled but site key is empty.');
return;
}
setTurnstileModalVisible(true);
return;
}
if (token && shouldTriggerTurnstile(message)) {
setTurnstileWidgetKey((v) => v + 1);
}
showError(message || t('签到失败'));
}
} catch (error) {
showError(t('签到失败'));
} finally {
setCheckinLoading(false);
}
};
useEffect(() => {
if (status?.checkin_enabled) {
fetchCheckinStatus(currentMonth);
}
}, [status?.checkin_enabled, currentMonth]);
// 如果签到功能未启用,不显示组件
if (!status?.checkin_enabled) {
return null;
}
// 日期渲染函数 - 显示签到状态和获得的额度
const dateRender = (dateString) => {
// Semi Calendar 传入的 dateString 是 Date.toString() 格式
// 需要转换为 YYYY-MM-DD 格式来匹配后端数据
const date = new Date(dateString);
if (isNaN(date.getTime())) {
return null;
}
// 使用本地时间格式化,避免时区问题
const year = date.getFullYear();
const month = String(date.getMonth() + 1).padStart(2, '0');
const day = String(date.getDate()).padStart(2, '0');
const formattedDate = `${year}-${month}-${day}`; // YYYY-MM-DD
const quotaAwarded = checkinRecordsMap[formattedDate];
const isCheckedIn = quotaAwarded !== undefined;
if (isCheckedIn) {
return (
<Tooltip
content={`${t('获得')} ${renderQuota(quotaAwarded)}`}
position='top'
>
<div className='absolute inset-0 flex flex-col items-center justify-center cursor-pointer'>
<div className='w-6 h-6 rounded-full bg-green-500 flex items-center justify-center mb-0.5 shadow-sm'>
<Check size={14} className='text-white' strokeWidth={3} />
</div>
<div className='text-[10px] font-medium text-green-600 dark:text-green-400 leading-none'>
{renderQuota(quotaAwarded)}
</div>
</div>
</Tooltip>
);
}
return null;
};
// 处理月份变化
const handleMonthChange = (date) => {
const month = date.toISOString().slice(0, 7);
setCurrentMonth(month);
};
return (
<Card className='!rounded-2xl'>
<Modal
title='Security Check'
visible={turnstileModalVisible}
footer={null}
centered
onCancel={() => {
setTurnstileModalVisible(false);
setTurnstileWidgetKey((v) => v + 1);
}}
>
<div className='flex justify-center py-2'>
<Turnstile
key={turnstileWidgetKey}
sitekey={turnstileSiteKey}
onVerify={(token) => {
doCheckin(token);
}}
onExpire={() => {
setTurnstileWidgetKey((v) => v + 1);
}}
/>
</div>
</Modal>
{/* 卡片头部 */}
<div className='flex items-center justify-between'>
<div
className='flex items-center flex-1 cursor-pointer'
onClick={() => setIsCollapsed(!isCollapsed)}
>
<Avatar size='small' color='green' className='mr-3 shadow-md'>
<CalendarCheck size={16} />
</Avatar>
<div className='flex-1'>
<div className='flex items-center gap-2'>
<Typography.Text className='text-lg font-medium'>
{t('每日签到')}
</Typography.Text>
{isCollapsed ? (
<ChevronDown size={16} className='text-gray-400' />
) : (
<ChevronUp size={16} className='text-gray-400' />
)}
</div>
<div className='text-xs text-gray-500 dark:text-gray-400'>
{!initialLoaded
? t('正在加载签到状态...')
: checkinData.stats?.checked_in_today
? t('今日已签到,累计签到') +
` ${checkinData.stats?.total_checkins || 0} ` +
t('天')
: t('每日签到可获得随机额度奖励')}
</div>
</div>
</div>
<Button
type='primary'
theme='solid'
icon={<Gift size={16} />}
onClick={() => doCheckin()}
loading={checkinLoading || !initialLoaded}
disabled={!initialLoaded || checkinData.stats?.checked_in_today}
className='!bg-green-600 hover:!bg-green-700'
>
{!initialLoaded
? t('加载中...')
: checkinData.stats?.checked_in_today
? t('今日已签到')
: t('立即签到')}
</Button>
</div>
{/* 可折叠内容 */}
<Collapsible isOpen={isCollapsed === false} keepDOM>
{/* 签到统计 */}
<div className='grid grid-cols-3 gap-3 mb-4 mt-4'>
<div className='text-center p-2.5 bg-slate-50 dark:bg-slate-800 rounded-lg'>
<div className='text-xl font-bold text-green-600'>
{checkinData.stats?.total_checkins || 0}
</div>
<div className='text-xs text-gray-500'>{t('累计签到')}</div>
</div>
<div className='text-center p-2.5 bg-slate-50 dark:bg-slate-800 rounded-lg'>
<div className='text-xl font-bold text-orange-600'>
{renderQuota(monthlyQuota, 6)}
</div>
<div className='text-xs text-gray-500'>{t('本月获得')}</div>
</div>
<div className='text-center p-2.5 bg-slate-50 dark:bg-slate-800 rounded-lg'>
<div className='text-xl font-bold text-blue-600'>
{renderQuota(checkinData.stats?.total_quota || 0, 6)}
</div>
<div className='text-xs text-gray-500'>{t('累计获得')}</div>
</div>
</div>
{/* 签到日历 - 使用更紧凑的样式 */}
<Spin spinning={loading}>
<div className='border rounded-lg overflow-hidden checkin-calendar'>
<style>{`
.checkin-calendar .semi-calendar {
font-size: 13px;
}
.checkin-calendar .semi-calendar-month-header {
padding: 8px 12px;
}
.checkin-calendar .semi-calendar-month-week-row {
height: 28px;
}
.checkin-calendar .semi-calendar-month-week-row th {
font-size: 12px;
padding: 4px 0;
}
.checkin-calendar .semi-calendar-month-grid-row {
height: auto;
}
.checkin-calendar .semi-calendar-month-grid-row td {
height: 56px;
padding: 2px;
}
.checkin-calendar .semi-calendar-month-grid-row-cell {
position: relative;
height: 100%;
}
.checkin-calendar .semi-calendar-month-grid-row-cell-day {
position: absolute;
top: 4px;
left: 50%;
transform: translateX(-50%);
font-size: 12px;
z-index: 1;
}
.checkin-calendar .semi-calendar-month-same {
background: transparent;
}
.checkin-calendar .semi-calendar-month-today .semi-calendar-month-grid-row-cell-day {
background: var(--semi-color-primary);
color: white;border-radius: 50%;
width: 20px;
height: 20px;
display: flex;
align-items: center;
justify-content: center;}
`}</style>
<Calendar
mode='month'
onChange={handleMonthChange}
dateGridRender={(dateString, date) => dateRender(dateString)}
/>
</div>
</Spin>
{/* 签到说明 */}
<div className='mt-3 p-2.5 bg-slate-50 dark:bg-slate-800 rounded-lg'>
<Typography.Text type='tertiary' className='text-xs'>
<ul className='list-disc list-inside space-y-0.5'>
<li>{t('每日签到可获得随机额度奖励')}</li>
<li>{t('签到奖励将直接添加到您的账户余额')}</li>
<li>{t('每日仅可签到一次,请勿重复签到')}</li>
</ul>
</Typography.Text>
</div>
</Collapsible>
</Card>
);
};
export default CheckinCalendar;

View File

@@ -44,7 +44,10 @@ import CodeViewer from '../../../playground/CodeViewer';
import { StatusContext } from '../../../../context/Status';
import { UserContext } from '../../../../context/User';
import { useUserPermissions } from '../../../../hooks/common/useUserPermissions';
import { useSidebar } from '../../../../hooks/common/useSidebar';
import {
mergeAdminConfig,
useSidebar,
} from '../../../../hooks/common/useSidebar';
const NotificationSettings = ({
t,
@@ -82,6 +85,7 @@ const NotificationSettings = ({
enabled: true,
channel: true,
models: true,
deployment: true,
redemption: true,
user: true,
setting: true,
@@ -164,6 +168,7 @@ const NotificationSettings = ({
enabled: true,
channel: true,
models: true,
deployment: true,
redemption: true,
user: true,
setting: true,
@@ -178,14 +183,27 @@ const NotificationSettings = ({
try {
// 获取管理员全局配置
if (statusState?.status?.SidebarModulesAdmin) {
const adminConf = JSON.parse(statusState.status.SidebarModulesAdmin);
setAdminConfig(adminConf);
try {
const adminConf = JSON.parse(
statusState.status.SidebarModulesAdmin,
);
setAdminConfig(mergeAdminConfig(adminConf));
} catch (error) {
setAdminConfig(mergeAdminConfig(null));
}
} else {
setAdminConfig(mergeAdminConfig(null));
}
// 获取用户个人配置
const userRes = await API.get('/api/user/self');
if (userRes.data.success && userRes.data.data.sidebar_modules) {
const userConf = JSON.parse(userRes.data.data.sidebar_modules);
let userConf;
if (typeof userRes.data.data.sidebar_modules === 'string') {
userConf = JSON.parse(userRes.data.data.sidebar_modules);
} else {
userConf = userRes.data.data.sidebar_modules;
}
setSidebarModulesUser(userConf);
}
} catch (error) {
@@ -273,6 +291,11 @@ const NotificationSettings = ({
modules: [
{ key: 'channel', title: t('渠道管理'), description: t('API渠道配置') },
{ key: 'models', title: t('模型管理'), description: t('AI模型配置') },
{
key: 'deployment',
title: t('模型部署'),
description: t('模型部署管理'),
},
{
key: 'redemption',
title: t('兑换码管理'),
@@ -812,7 +835,9 @@ const NotificationSettings = ({
</Typography.Text>
</div>
<Switch
checked={sidebarModulesUser[section.key]?.enabled}
checked={
sidebarModulesUser[section.key]?.enabled !== false
}
onChange={handleSectionChange(section.key)}
size='default'
/>
@@ -835,7 +860,8 @@ const NotificationSettings = ({
>
<Card
className={`!rounded-xl border border-gray-200 hover:border-blue-300 transition-all duration-200 ${
sidebarModulesUser[section.key]?.enabled
sidebarModulesUser[section.key]?.enabled !==
false
? ''
: 'opacity-50'
}`}
@@ -866,7 +892,7 @@ const NotificationSettings = ({
checked={
sidebarModulesUser[section.key]?.[
module.key
]
] !== false
}
onChange={handleModuleChange(
section.key,
@@ -874,8 +900,8 @@ const NotificationSettings = ({
)}
size='default'
disabled={
!sidebarModulesUser[section.key]
?.enabled
sidebarModulesUser[section.key]
?.enabled === false
}
/>
</div>

View File

@@ -39,11 +39,16 @@ import {
showError,
} from '../../../helpers';
import { CHANNEL_OPTIONS } from '../../../constants';
import { IconTreeTriangleDown, IconMore } from '@douyinfe/semi-icons';
import {
IconTreeTriangleDown,
IconMore,
IconAlertTriangle,
} from '@douyinfe/semi-icons';
import { FaRandom } from 'react-icons/fa';
// Render functions
const renderType = (type, channelInfo = undefined, t) => {
const renderType = (type, record = {}, t) => {
const channelInfo = record?.channel_info;
let type2label = new Map();
for (let i = 0; i < CHANNEL_OPTIONS.length; i++) {
type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i];
@@ -67,11 +72,65 @@ const renderType = (type, channelInfo = undefined, t) => {
);
}
return (
const typeTag = (
<Tag color={type2label[type]?.color} shape='circle' prefixIcon={icon}>
{type2label[type]?.label}
</Tag>
);
let ionetMeta = null;
if (record?.other_info) {
try {
const parsed = JSON.parse(record.other_info);
if (parsed && typeof parsed === 'object' && parsed.source === 'ionet') {
ionetMeta = parsed;
}
} catch (error) {
// ignore invalid metadata
}
}
if (!ionetMeta) {
return typeTag;
}
const handleNavigate = (event) => {
event?.stopPropagation?.();
if (!ionetMeta?.deployment_id) {
return;
}
const targetUrl = `/console/deployment?deployment_id=${ionetMeta.deployment_id}`;
window.open(targetUrl, '_blank', 'noopener');
};
return (
<Space spacing={6}>
{typeTag}
<Tooltip
content={
<div className='max-w-xs'>
<div className='text-xs text-gray-600'>{t('来源于 IO.NET 部署')}</div>
{ionetMeta?.deployment_id && (
<div className='text-xs text-gray-500 mt-1'>
{t('部署 ID')}: {ionetMeta.deployment_id}
</div>
)}
</div>
}
>
<span>
<Tag
color='purple'
type='light'
className='cursor-pointer'
onClick={handleNavigate}
>
IO.NET
</Tag>
</span>
</Tooltip>
</Space>
);
};
const renderTagType = (t) => {
@@ -187,6 +246,28 @@ const renderResponseTime = (responseTime, t) => {
}
};
const isRequestPassThroughEnabled = (record) => {
if (!record || record.children !== undefined) {
return false;
}
const settingValue = record.setting;
if (!settingValue) {
return false;
}
if (typeof settingValue === 'object') {
return settingValue.pass_through_body_enabled === true;
}
if (typeof settingValue !== 'string') {
return false;
}
try {
const parsed = JSON.parse(settingValue);
return parsed?.pass_through_body_enabled === true;
} catch (error) {
return false;
}
};
export const getChannelsColumns = ({
t,
COLUMN_KEYS,
@@ -205,6 +286,7 @@ export const getChannelsColumns = ({
refresh,
activePage,
channels,
checkOllamaVersion,
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel,
}) => {
@@ -219,8 +301,9 @@ export const getChannelsColumns = ({
title: t('名称'),
dataIndex: 'name',
render: (text, record, index) => {
if (record.remark && record.remark.trim() !== '') {
return (
const passThroughEnabled = isRequestPassThroughEnabled(record);
const nameNode =
record.remark && record.remark.trim() !== '' ? (
<Tooltip
content={
<div className='flex flex-col gap-2 max-w-xs'>
@@ -250,9 +333,32 @@ export const getChannelsColumns = ({
>
<span>{text}</span>
</Tooltip>
) : (
<span>{text}</span>
);
if (!passThroughEnabled) {
return nameNode;
}
return text;
return (
<Space spacing={6} align='center'>
{nameNode}
<Tooltip
content={t(
'该渠道已开启请求透传:参数覆写、模型重定向、渠道适配等 NewAPI 内置功能将失效,非最佳实践;如因此产生问题,请勿提交 issue 反馈。',
)}
trigger='hover'
position='topLeft'
>
<span className='inline-flex items-center'>
<IconAlertTriangle
style={{ color: 'var(--semi-color-warning)' }}
/>
</span>
</Tooltip>
</Space>
);
},
},
{
@@ -280,12 +386,7 @@ export const getChannelsColumns = ({
dataIndex: 'type',
render: (text, record, index) => {
if (record.children === undefined) {
if (record.channel_info) {
if (record.channel_info.is_multi_key) {
return <>{renderType(text, record.channel_info, t)}</>;
}
}
return <>{renderType(text, undefined, t)}</>;
return <>{renderType(text, record, t)}</>;
} else {
return <>{renderTagType(t)}</>;
}
@@ -519,6 +620,15 @@ export const getChannelsColumns = ({
},
];
if (record.type === 4) {
moreMenuItems.unshift({
node: 'item',
name: t('测活'),
type: 'tertiary',
onClick: () => checkOllamaVersion(record),
});
}
return (
<Space wrap>
<SplitButtonGroup

View File

@@ -57,6 +57,7 @@ const ChannelsTable = (channelsData) => {
setEditingTag,
copySelectedChannel,
refresh,
checkOllamaVersion,
// Multi-key management
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel,
@@ -82,6 +83,7 @@ const ChannelsTable = (channelsData) => {
refresh,
activePage,
channels,
checkOllamaVersion,
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel,
});
@@ -103,6 +105,7 @@ const ChannelsTable = (channelsData) => {
refresh,
activePage,
channels,
checkOllamaVersion,
setShowMultiKeyManageModal,
setCurrentMultiKeyChannel,
]);

View File

@@ -18,6 +18,8 @@ For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Banner } from '@douyinfe/semi-ui';
import { IconAlertTriangle } from '@douyinfe/semi-icons';
import CardPro from '../../common/ui/CardPro';
import ChannelsTable from './ChannelsTable';
import ChannelsActions from './ChannelsActions';
@@ -63,6 +65,22 @@ const ChannelsPage = () => {
/>
{/* Main Content */}
{channelsData.globalPassThroughEnabled ? (
<Banner
type='warning'
closeIcon={null}
icon={
<IconAlertTriangle
size='large'
style={{ color: 'var(--semi-color-warning)' }}
/>
}
description={channelsData.t(
'已开启全局请求透传:参数覆写、模型重定向、渠道适配等 NewAPI 内置功能将失效,非最佳实践;如因此产生问题,请勿提交 issue 反馈。',
)}
style={{ marginBottom: 12 }}
/>
) : null}
<CardPro
type='type3'
tabsArea={<ChannelsTabs {...channelsData} />}

View File

@@ -55,6 +55,7 @@ import {
selectFilter,
} from '../../../../helpers';
import ModelSelectModal from './ModelSelectModal';
import OllamaModelModal from './OllamaModelModal';
import JSONEditor from '../../../common/ui/JSONEditor';
import SecureVerificationModal from '../../../common/modals/SecureVerificationModal';
import ChannelKeyDisplay from '../../../common/ui/ChannelKeyDisplay';
@@ -180,6 +181,7 @@ const EditChannelModal = (props) => {
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
const [modelModalVisible, setModelModalVisible] = useState(false);
const [fetchedModels, setFetchedModels] = useState([]);
const [ollamaModalVisible, setOllamaModalVisible] = useState(false);
const formApiRef = useRef(null);
const [vertexKeys, setVertexKeys] = useState([]);
const [vertexFileList, setVertexFileList] = useState([]);
@@ -214,6 +216,8 @@ const EditChannelModal = (props) => {
return [];
}
}, [inputs.model_mapping]);
const [isIonetChannel, setIsIonetChannel] = useState(false);
const [ionetMetadata, setIonetMetadata] = useState(null);
// 密钥显示状态
const [keyDisplayState, setKeyDisplayState] = useState({
@@ -224,6 +228,21 @@ const EditChannelModal = (props) => {
// 专门的2FA验证状态用于TwoFactorAuthModal
const [show2FAVerifyModal, setShow2FAVerifyModal] = useState(false);
const [verifyCode, setVerifyCode] = useState('');
useEffect(() => {
if (!isEdit) {
setIsIonetChannel(false);
setIonetMetadata(null);
}
}, [isEdit]);
const handleOpenIonetDeployment = () => {
if (!ionetMetadata?.deployment_id) {
return;
}
const targetUrl = `/console/deployment?deployment_id=${ionetMetadata.deployment_id}`;
window.open(targetUrl, '_blank', 'noopener');
};
const [verifyLoading, setVerifyLoading] = useState(false);
// 表单块导航相关状态
@@ -404,7 +423,12 @@ const EditChannelModal = (props) => {
handleInputChange('settings', settingsJson);
};
const isIonetLocked = isIonetChannel && isEdit;
const handleInputChange = (name, value) => {
if (isIonetChannel && isEdit && ['type', 'key', 'base_url'].includes(name)) {
return;
}
if (formApiRef.current) {
formApiRef.current.setValue(name, value);
}
@@ -625,6 +649,25 @@ const EditChannelModal = (props) => {
.map((model) => (model || '').trim())
.filter(Boolean);
initialModelMappingRef.current = data.model_mapping || '';
let parsedIonet = null;
if (data.other_info) {
try {
const maybeMeta = JSON.parse(data.other_info);
if (
maybeMeta &&
typeof maybeMeta === 'object' &&
maybeMeta.source === 'ionet'
) {
parsedIonet = maybeMeta;
}
} catch (error) {
// ignore parse error
}
}
const managedByIonet = !!parsedIonet;
setIsIonetChannel(managedByIonet);
setIonetMetadata(parsedIonet);
// console.log(data);
} else {
showError(message);
@@ -632,7 +675,8 @@ const EditChannelModal = (props) => {
setLoading(false);
};
const fetchUpstreamModelList = async (name) => {
const fetchUpstreamModelList = async (name, options = {}) => {
const silent = !!options.silent;
// if (inputs['type'] !== 1) {
// showError(t('仅支持 OpenAI 接口格式'));
// return;
@@ -683,7 +727,9 @@ const EditChannelModal = (props) => {
if (!err) {
const uniqueModels = Array.from(new Set(models));
setFetchedModels(uniqueModels);
setModelModalVisible(true);
if (!silent) {
setModelModalVisible(true);
}
} else {
showError(t('获取模型列表失败'));
}
@@ -1626,20 +1672,44 @@ const EditChannelModal = (props) => {
</div>
</div>
<Form.Select
field='type'
label={t('类型')}
placeholder={t('请选择渠道类型')}
rules={[{ required: true, message: t('请选择渠道类型') }]}
optionList={channelOptionList}
style={{ width: '100%' }}
filter={selectFilter}
autoClearSearchValue={false}
searchPosition='dropdown'
onSearch={(value) => setChannelSearchValue(value)}
renderOptionItem={renderChannelOption}
onChange={(value) => handleInputChange('type', value)}
/>
{isIonetChannel && (
<Banner
type='info'
closeIcon={null}
className='mb-4 rounded-xl'
description={t('此渠道由 IO.NET 自动同步,类型、密钥和 API 地址已锁定。')}
>
<Space>
{ionetMetadata?.deployment_id && (
<Button
size='small'
theme='light'
type='primary'
icon={<IconGlobe />}
onClick={handleOpenIonetDeployment}
>
{t('查看关联部署')}
</Button>
)}
</Space>
</Banner>
)}
<Form.Select
field='type'
label={t('类型')}
placeholder={t('请选择渠道类型')}
rules={[{ required: true, message: t('请选择渠道类型') }]}
optionList={channelOptionList}
style={{ width: '100%' }}
filter={selectFilter}
autoClearSearchValue={false}
searchPosition='dropdown'
onSearch={(value) => setChannelSearchValue(value)}
renderOptionItem={renderChannelOption}
onChange={(value) => handleInputChange('type', value)}
disabled={isIonetLocked}
/>
{inputs.type === 20 && (
<Form.Switch
@@ -1778,87 +1848,86 @@ const EditChannelModal = (props) => {
autosize
autoComplete='new-password'
onChange={(value) => handleInputChange('key', value)}
extraText={
<div className='flex items-center gap-2 flex-wrap'>
{isEdit &&
isMultiKeyChannel &&
keyMode === 'append' && (
<Text type='warning' size='small'>
{t(
'追加模式:新密钥将添加到现有密钥列表的末尾',
)}
</Text>
)}
{isEdit && (
disabled={isIonetLocked}
extraText={
<div className='flex items-center gap-2 flex-wrap'>
{isEdit &&
isMultiKeyChannel &&
keyMode === 'append' && (
<Text type='warning' size='small'>
{t(
'追加模式:新密钥将添加到现有密钥列表的末尾',
)}
</Text>
)}
{isEdit && (
<Button
size='small'
type='primary'
theme='outline'
onClick={handleShow2FAModal}
>
{t('查看密钥')}
</Button>
)}
{batchExtra}
</div>
}
showClear
/>
)
) : (
<>
{inputs.type === 41 &&
(inputs.vertex_key_type || 'json') === 'json' ? (
<>
{!batch && (
<div className='flex items-center justify-between mb-3'>
<Text className='text-sm font-medium'>
{t('密钥输入方式')}
</Text>
<Space>
<Button
size='small'
type='primary'
theme='outline'
onClick={handleShow2FAModal}
type={
!useManualInput ? 'primary' : 'tertiary'
}
onClick={() => {
setUseManualInput(false);
// 切换到文件上传模式时清空手动输入的密钥
if (formApiRef.current) {
formApiRef.current.setValue('key', '');
}
handleInputChange('key', '');
}}
>
{t('查看密钥')}
{t('文件上传')}
</Button>
)}
{batchExtra}
<Button
size='small'
type={useManualInput ? 'primary' : 'tertiary'}
onClick={() => {
setUseManualInput(true);
// 切换到手动输入模式时清空文件上传相关状态
setVertexKeys([]);
setVertexFileList([]);
if (formApiRef.current) {
formApiRef.current.setValue(
'vertex_files',
[],
);
}
setInputs((prev) => ({
...prev,
vertex_files: [],
}));
}}
>
{t('手动输入')}
</Button>
</Space>
</div>
}
showClear
/>
)
) : (
<>
{inputs.type === 41 &&
(inputs.vertex_key_type || 'json') === 'json' ? (
<>
{!batch && (
<div className='flex items-center justify-between mb-3'>
<Text className='text-sm font-medium'>
{t('密钥输入方式')}
</Text>
<Space>
<Button
size='small'
type={
!useManualInput ? 'primary' : 'tertiary'
}
onClick={() => {
setUseManualInput(false);
// 切换到文件上传模式时清空手动输入的密钥
if (formApiRef.current) {
formApiRef.current.setValue('key', '');
}
handleInputChange('key', '');
}}
>
{t('文件上传')}
</Button>
<Button
size='small'
type={
useManualInput ? 'primary' : 'tertiary'
}
onClick={() => {
setUseManualInput(true);
// 切换到手动输入模式时清空文件上传相关状态
setVertexKeys([]);
setVertexFileList([]);
if (formApiRef.current) {
formApiRef.current.setValue(
'vertex_files',
[],
);
}
setInputs((prev) => ({
...prev,
vertex_files: [],
}));
}}
>
{t('手动输入')}
</Button>
</Space>
</div>
)}
)}
{batch && (
<Banner
@@ -2189,84 +2258,86 @@ const EditChannelModal = (props) => {
/>
)}
{inputs.type === 3 && (
<>
<Banner
type='warning'
description={t(
'2025年5月10日后添加的渠道不需要再在部署的时候移除模型名称中的"."',
{inputs.type === 3 && (
<>
<Banner
type='warning'
description={t(
'2025年5月10日后添加的渠道不需要再在部署的时候移除模型名称中的"."',
)}
className='!rounded-lg'
/>
<div>
<Form.Input
field='base_url'
label='AZURE_OPENAI_ENDPOINT'
placeholder={t(
'请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com',
)}
className='!rounded-lg'
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
disabled={isIonetLocked}
/>
<div>
<Form.Input
field='base_url'
label='AZURE_OPENAI_ENDPOINT'
placeholder={t(
'请输入 AZURE_OPENAI_ENDPOINT例如https://docs-test-001.openai.azure.com',
)}
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
/>
</div>
<div>
<Form.Input
field='other'
label={t('默认 API 版本')}
placeholder={t(
'请输入默认 API 版本例如2025-04-01-preview',
)}
onChange={(value) =>
handleInputChange('other', value)
}
showClear
/>
</div>
<div>
<Form.Input
field='azure_responses_version'
label={t(
'默认 Responses API 版本,为空则使用上方版本',
)}
placeholder={t('例如preview')}
onChange={(value) =>
handleChannelOtherSettingsChange(
'azure_responses_version',
value,
)
}
showClear
/>
</div>
</>
)}
</div>
<div>
<Form.Input
field='other'
label={t('默认 API 版本')}
placeholder={t(
'请输入默认 API 版本例如2025-04-01-preview',
)}
onChange={(value) =>
handleInputChange('other', value)
}
showClear
/>
</div>
<div>
<Form.Input
field='azure_responses_version'
label={t(
'默认 Responses API 版本,为空则使用上方版本',
)}
placeholder={t('例如preview')}
onChange={(value) =>
handleChannelOtherSettingsChange(
'azure_responses_version',
value,
)
}
showClear
/>
</div>
</>
)}
{inputs.type === 8 && (
<>
<Banner
type='warning'
description={t(
'如果你对接的是上游One API或者New API等转发项目请使用OpenAI类型不要使用此类型除非你知道你在做什么。',
{inputs.type === 8 && (
<>
<Banner
type='warning'
description={t(
'如果你对接的是上游One API或者New API等转发项目请使用OpenAI类型不要使用此类型除非你知道你在做什么。',
)}
className='!rounded-lg'
/>
<div>
<Form.Input
field='base_url'
label={t('完整的 Base URL支持变量{model}')}
placeholder={t(
'请输入完整的URL例如https://api.openai.com/v1/chat/completions',
)}
className='!rounded-lg'
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
disabled={isIonetLocked}
/>
<div>
<Form.Input
field='base_url'
label={t('完整的 Base URL支持变量{model}')}
placeholder={t(
'请输入完整的URL例如https://api.openai.com/v1/chat/completions',
)}
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
/>
</div>
</>
)}
</div>
</>
)}
{inputs.type === 37 && (
<Banner
@@ -2294,76 +2365,77 @@ const EditChannelModal = (props) => {
handleInputChange('base_url', value)
}
showClear
extraText={t(
'对于官方渠道new-api已经内置地址除非是第三方代理站点或者Azure的特殊接入地址否则不需要填写',
)}
/>
</div>
)}
{inputs.type === 22 && (
<div>
<Form.Input
field='base_url'
label={t('私有部署地址')}
placeholder={t(
'请输入私有部署地址格式为https://fastgpt.run/api/openapi',
disabled={isIonetLocked}
extraText={t(
'对于官方渠道new-api已经内置地址除非是第三方代理站点或者Azure的特殊接入地址否则不需要填写',
)}
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
/>
</div>
)}
{inputs.type === 36 && (
<div>
<Form.Input
field='base_url'
label={t(
'注意非Chat API请务必填写正确的API地址否则可能导致无法使用',
)}
placeholder={t(
'请输入到 /suno 前的路径通常就是域名例如https://api.example.com',
)}
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
/>
</div>
)}
{inputs.type === 22 && (
<div>
<Form.Input
field='base_url'
label={t('私有部署地址')}
placeholder={t(
'请输入私有部署地址格式为https://fastgpt.run/api/openapi',
)}
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
disabled={isIonetLocked}
/>
</div>
)}
{inputs.type === 45 && !doubaoApiEditUnlocked && (
<div>
<Form.Select
field='base_url'
label={t('API地址')}
placeholder={t('请选择API地址')}
onChange={(value) =>
{inputs.type === 36 && (
<div>
<Form.Input
field='base_url'
label={t(
'注意非Chat API请务必填写正确的API地址否则可能导致无法使用',
)}
placeholder={t(
'请输入到 /suno 前的路径通常就是域名例如https://api.example.com',
)}
onChange={(value) =>
handleInputChange('base_url', value)
}
showClear
disabled={isIonetLocked}
/>
</div>
)}
{inputs.type === 45 && !doubaoApiEditUnlocked && (
<div>
<Form.Select
field='base_url'
label={t('API地址')}
placeholder={t('请选择API地址')}
onChange={(value) =>
handleInputChange('base_url', value)
}
optionList={[
{
value: 'https://ark.cn-beijing.volces.com',
label: 'https://ark.cn-beijing.volces.com',
},
{
value:
'https://ark.ap-southeast.bytepluses.com',
label:
'https://ark.ap-southeast.bytepluses.com',
},
{
value: 'doubao-coding-plan',
}
optionList={[
{
value: 'https://ark.cn-beijing.volces.com',
label: 'https://ark.cn-beijing.volces.com',
},
{
value: 'https://ark.ap-southeast.bytepluses.com',
label: 'https://ark.ap-southeast.bytepluses.com',
},
{
value: 'doubao-coding-plan',
label: 'Doubao Coding Plan',
},
]}
defaultValue='https://ark.cn-beijing.volces.com'
/>
</div>
)}
]}defaultValue='https://ark.cn-beijing.volces.com'
disabled={isIonetLocked}
/>
</div>
)}
</Card>
</div>
)}
@@ -2458,72 +2530,80 @@ const EditChannelModal = (props) => {
{t('获取模型列表')}
</Button>
)}
{inputs.type === 4 && isEdit && (
<Button
size='small'
type='warning'
onClick={() => handleInputChange('models', [])}
type='primary'
theme='light'
onClick={() => setOllamaModalVisible(true)}
>
{t('清除所有模型')}
{t('Ollama 模型管理')}
</Button>
<Button
size='small'
type='tertiary'
onClick={() => {
if (inputs.models.length === 0) {
showInfo(t('没有模型可以复制'));
return;
}
try {
copy(inputs.models.join(','));
showSuccess(t('模型列表已复制到剪贴板'));
} catch (error) {
showError(t('复制失败'));
}
}}
>
{t('复制所有模型')}
</Button>
{modelGroups &&
modelGroups.length > 0 &&
modelGroups.map((group) => (
<Button
key={group.id}
size='small'
type='primary'
onClick={() => {
let items = [];
try {
if (Array.isArray(group.items)) {
items = group.items;
} else if (
typeof group.items === 'string'
) {
const parsed = JSON.parse(
group.items || '[]',
);
if (Array.isArray(parsed)) items = parsed;
}
} catch {}
const current =
formApiRef.current?.getValue('models') ||
inputs.models ||
[];
const merged = Array.from(
new Set(
[...current, ...items]
.map((m) => (m || '').trim())
.filter(Boolean),
),
);
handleInputChange('models', merged);
}}
>
{group.name}
</Button>
))}
</Space>
}
/>
)}
<Button
size='small'
type='warning'
onClick={() => handleInputChange('models', [])}
>
{t('清除所有模型')}
</Button>
<Button
size='small'
type='tertiary'
onClick={() => {
if (inputs.models.length === 0) {
showInfo(t('没有模型可以复制'));
return;
}
try {
copy(inputs.models.join(','));
showSuccess(t('模型列表已复制到剪贴板'));
} catch (error) {
showError(t('复制失败'));
}
}}
>
{t('复制所有模型')}
</Button>
{modelGroups &&
modelGroups.length > 0 &&
modelGroups.map((group) => (
<Button
key={group.id}
size='small'
type='primary'
onClick={() => {
let items = [];
try {
if (Array.isArray(group.items)) {
items = group.items;
} else if (typeof group.items === 'string') {
const parsed = JSON.parse(
group.items || '[]',
);
if (Array.isArray(parsed)) items = parsed;
}
} catch {}
const current =
formApiRef.current?.getValue('models') ||
inputs.models ||
[];
const merged = Array.from(
new Set(
[...current, ...items]
.map((m) => (m || '').trim())
.filter(Boolean),
),
);
handleInputChange('models', merged);
}}
>
{group.name}
</Button>
))}
</Space>
}
/>
<Form.Input
field='custom_model'
@@ -3083,6 +3163,33 @@ const EditChannelModal = (props) => {
}}
onCancel={() => setModelModalVisible(false)}
/>
<OllamaModelModal
visible={ollamaModalVisible}
onCancel={() => setOllamaModalVisible(false)}
channelId={channelId}
channelInfo={inputs}
onModelsUpdate={(options = {}) => {
// 当模型更新后,重新获取模型列表以更新表单
fetchUpstreamModelList('models', { silent: !!options.silent });
}}
onApplyModels={({ mode, modelIds } = {}) => {
if (!Array.isArray(modelIds) || modelIds.length === 0) {
return;
}
const existingModels = Array.isArray(inputs.models)
? inputs.models.map(String)
: [];
const incoming = modelIds.map(String);
const nextModels = Array.from(new Set([...existingModels, ...incoming]));
handleInputChange('models', nextModels);
if (formApiRef.current) {
formApiRef.current.setValue('models', nextModels);
}
showSuccess(t('模型列表已追加更新'));
}}
/>
</>
);
};

View File

@@ -47,7 +47,20 @@ const ModelSelectModal = ({
onCancel,
}) => {
const { t } = useTranslation();
const [checkedList, setCheckedList] = useState(selected);
const getModelName = (model) => {
if (!model) return '';
if (typeof model === 'string') return model;
if (typeof model === 'object' && model.model_name) return model.model_name;
return String(model ?? '');
};
const normalizedSelected = useMemo(
() => (selected || []).map(getModelName),
[selected],
);
const [checkedList, setCheckedList] = useState(normalizedSelected);
const [keyword, setKeyword] = useState('');
const [activeTab, setActiveTab] = useState('new');
@@ -105,9 +118,9 @@ const ModelSelectModal = ({
// 同步外部选中值
useEffect(() => {
if (visible) {
setCheckedList(selected);
setCheckedList(normalizedSelected);
}
}, [visible, selected]);
}, [visible, normalizedSelected]);
// 当模型列表变化时设置默认tab
useEffect(() => {

View File

@@ -265,6 +265,11 @@ const ModelTestModal = ({
placeholder={t('选择端点类型')}
/>
</div>
<Typography.Text type='tertiary' size='small' className='block mb-2'>
{t(
'说明:本页测试为非流式请求;若渠道仅支持流式返回,可能出现测试失败,请以实际使用为准。',
)}
</Typography.Text>
{/* 搜索与操作按钮 */}
<div className='flex items-center justify-end gap-2 w-full mb-2'>

View File

@@ -0,0 +1,778 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import {
Modal,
Button,
Typography,
Card,
List,
Space,
Input,
Spin,
Popconfirm,
Tag,
Empty,
Row,
Col,
Progress,
Checkbox,
} from '@douyinfe/semi-ui';
import {
IconDownload,
IconDelete,
IconRefresh,
IconSearch,
IconPlus,
} from '@douyinfe/semi-icons';
import {
API,
authHeader,
getUserIdFromLocalStorage,
showError,
showSuccess,
} from '../../../../helpers';
const { Text, Title } = Typography;
const CHANNEL_TYPE_OLLAMA = 4;
const parseMaybeJSON = (value) => {
if (!value) return null;
if (typeof value === 'object') return value;
if (typeof value === 'string') {
try {
return JSON.parse(value);
} catch (error) {
return null;
}
}
return null;
};
const resolveOllamaBaseUrl = (info) => {
if (!info) {
return '';
}
const direct = typeof info.base_url === 'string' ? info.base_url.trim() : '';
if (direct) {
return direct;
}
const alt =
typeof info.ollama_base_url === 'string' ? info.ollama_base_url.trim() : '';
if (alt) {
return alt;
}
const parsed = parseMaybeJSON(info.other_info);
if (parsed && typeof parsed === 'object') {
const candidate =
(typeof parsed.base_url === 'string' && parsed.base_url.trim()) ||
(typeof parsed.public_url === 'string' && parsed.public_url.trim()) ||
(typeof parsed.api_url === 'string' && parsed.api_url.trim());
if (candidate) {
return candidate;
}
}
return '';
};
const normalizeModels = (items) => {
if (!Array.isArray(items)) {
return [];
}
return items
.map((item) => {
if (!item) {
return null;
}
if (typeof item === 'string') {
return {
id: item,
owned_by: 'ollama',
};
}
if (typeof item === 'object') {
const candidateId =
item.id || item.ID || item.name || item.model || item.Model;
if (!candidateId) {
return null;
}
const metadata = item.metadata || item.Metadata;
const normalized = {
...item,
id: candidateId,
owned_by: item.owned_by || item.ownedBy || 'ollama',
};
if (typeof item.size === 'number' && !normalized.size) {
normalized.size = item.size;
}
if (metadata && typeof metadata === 'object') {
if (typeof metadata.size === 'number' && !normalized.size) {
normalized.size = metadata.size;
}
if (!normalized.digest && typeof metadata.digest === 'string') {
normalized.digest = metadata.digest;
}
if (
!normalized.modified_at &&
typeof metadata.modified_at === 'string'
) {
normalized.modified_at = metadata.modified_at;
}
if (metadata.details && !normalized.details) {
normalized.details = metadata.details;
}
}
return normalized;
}
return null;
})
.filter(Boolean);
};
const OllamaModelModal = ({
visible,
onCancel,
channelId,
channelInfo,
onModelsUpdate,
onApplyModels,
}) => {
const { t } = useTranslation();
const [loading, setLoading] = useState(false);
const [models, setModels] = useState([]);
const [filteredModels, setFilteredModels] = useState([]);
const [searchValue, setSearchValue] = useState('');
const [pullModelName, setPullModelName] = useState('');
const [pullLoading, setPullLoading] = useState(false);
const [pullProgress, setPullProgress] = useState(null);
const [eventSource, setEventSource] = useState(null);
const [selectedModelIds, setSelectedModelIds] = useState([]);
const handleApplyAllModels = () => {
if (!onApplyModels || selectedModelIds.length === 0) {
return;
}
onApplyModels({ mode: 'append', modelIds: selectedModelIds });
};
const handleToggleModel = (modelId, checked) => {
if (!modelId) {
return;
}
setSelectedModelIds((prev) => {
if (checked) {
if (prev.includes(modelId)) {
return prev;
}
return [...prev, modelId];
}
return prev.filter((id) => id !== modelId);
});
};
const handleSelectAll = () => {
setSelectedModelIds(models.map((item) => item?.id).filter(Boolean));
};
const handleClearSelection = () => {
setSelectedModelIds([]);
};
// 获取模型列表
const fetchModels = async () => {
const channelType = Number(channelInfo?.type ?? CHANNEL_TYPE_OLLAMA);
const shouldTryLiveFetch = channelType === CHANNEL_TYPE_OLLAMA;
const resolvedBaseUrl = resolveOllamaBaseUrl(channelInfo);
setLoading(true);
let liveFetchSucceeded = false;
let fallbackSucceeded = false;
let lastError = '';
let nextModels = [];
try {
if (shouldTryLiveFetch && resolvedBaseUrl) {
try {
const payload = {
base_url: resolvedBaseUrl,
type: CHANNEL_TYPE_OLLAMA,
key: channelInfo?.key || '',
};
const res = await API.post('/api/channel/fetch_models', payload, {
skipErrorHandler: true,
});
if (res?.data?.success) {
nextModels = normalizeModels(res.data.data);
liveFetchSucceeded = true;
} else if (res?.data?.message) {
lastError = res.data.message;
}
} catch (error) {
const message = error?.response?.data?.message || error.message;
if (message) {
lastError = message;
}
}
} else if (shouldTryLiveFetch && !resolvedBaseUrl && !channelId) {
lastError = t('请先填写 Ollama API 地址');
}
if ((!liveFetchSucceeded || nextModels.length === 0) && channelId) {
try {
const res = await API.get(`/api/channel/fetch_models/${channelId}`, {
skipErrorHandler: true,
});
if (res?.data?.success) {
nextModels = normalizeModels(res.data.data);
fallbackSucceeded = true;
lastError = '';
} else if (res?.data?.message) {
lastError = res.data.message;
}
} catch (error) {
const message = error?.response?.data?.message || error.message;
if (message) {
lastError = message;
}
}
}
if (!liveFetchSucceeded && !fallbackSucceeded && lastError) {
showError(`${t('获取模型列表失败')}: ${lastError}`);
}
const normalized = nextModels;
setModels(normalized);
setFilteredModels(normalized);
setSelectedModelIds((prev) => {
if (!normalized || normalized.length === 0) {
return [];
}
if (!prev || prev.length === 0) {
return normalized.map((item) => item.id).filter(Boolean);
}
const available = prev.filter((id) =>
normalized.some((item) => item.id === id),
);
return available.length > 0
? available
: normalized.map((item) => item.id).filter(Boolean);
});
} finally {
setLoading(false);
}
};
// 拉取模型 (流式,支持进度)
const pullModel = async () => {
if (!pullModelName.trim()) {
showError(t('请输入模型名称'));
return;
}
setPullLoading(true);
setPullProgress({ status: 'starting', completed: 0, total: 0 });
let hasRefreshed = false;
const refreshModels = async () => {
if (hasRefreshed) return;
hasRefreshed = true;
await fetchModels();
if (onModelsUpdate) {
onModelsUpdate({ silent: true });
}
};
try {
// 关闭之前的连接
if (eventSource) {
eventSource.close();
setEventSource(null);
}
const controller = new AbortController();
const closable = {
close: () => controller.abort(),
};
setEventSource(closable);
// 使用 fetch 请求 SSE 流
const authHeaders = authHeader();
const userId = getUserIdFromLocalStorage();
const fetchHeaders = {
'Content-Type': 'application/json',
Accept: 'text/event-stream',
'New-API-User': String(userId),
...authHeaders,
};
const response = await fetch('/api/channel/ollama/pull/stream', {
method: 'POST',
headers: fetchHeaders,
body: JSON.stringify({
channel_id: channelId,
model_name: pullModelName.trim(),
}),
signal: controller.signal,
});
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
// 读取 SSE 流
const processStream = async () => {
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split('\n');
buffer = lines.pop() || '';
for (const line of lines) {
if (!line.startsWith('data: ')) {
continue;
}
try {
const eventData = line.substring(6);
if (eventData === '[DONE]') {
setPullLoading(false);
setPullProgress(null);
setEventSource(null);
return;
}
const data = JSON.parse(eventData);
if (data.status) {
// 处理进度数据
setPullProgress(data);
} else if (data.error) {
// 处理错误
showError(data.error);
setPullProgress(null);
setPullLoading(false);
setEventSource(null);
return;
} else if (data.message) {
// 处理成功消息
showSuccess(data.message);
setPullModelName('');
setPullProgress(null);
setPullLoading(false);
setEventSource(null);
await fetchModels();
if (onModelsUpdate) {
onModelsUpdate({ silent: true });
}
await refreshModels();
return;
}
} catch (e) {
console.error('Failed to parse SSE data:', e);
}
}
}
// 正常结束流
setPullLoading(false);
setPullProgress(null);
setEventSource(null);
await refreshModels();
} catch (error) {
if (error?.name === 'AbortError') {
setPullProgress(null);
setPullLoading(false);
setEventSource(null);
return;
}
console.error('Stream processing error:', error);
showError(t('数据传输中断'));
setPullProgress(null);
setPullLoading(false);
setEventSource(null);
await refreshModels();
}
};
await processStream();
} catch (error) {
if (error?.name !== 'AbortError') {
showError(t('模型拉取失败: {{error}}', { error: error.message }));
}
setPullLoading(false);
setPullProgress(null);
setEventSource(null);
await refreshModels();
}
};
// 删除模型
const deleteModel = async (modelName) => {
try {
const res = await API.delete('/api/channel/ollama/delete', {
data: {
channel_id: channelId,
model_name: modelName,
},
});
if (res.data.success) {
showSuccess(t('模型删除成功'));
await fetchModels(); // 重新获取模型列表
if (onModelsUpdate) {
onModelsUpdate({ silent: true }); // 通知父组件更新
}
} else {
showError(res.data.message || t('模型删除失败'));
}
} catch (error) {
showError(t('模型删除失败: {{error}}', { error: error.message }));
}
};
// 搜索过滤
useEffect(() => {
if (!searchValue) {
setFilteredModels(models);
} else {
const filtered = models.filter((model) =>
model.id.toLowerCase().includes(searchValue.toLowerCase()),
);
setFilteredModels(filtered);
}
}, [models, searchValue]);
useEffect(() => {
if (!visible) {
setSelectedModelIds([]);
setPullModelName('');
setPullProgress(null);
setPullLoading(false);
}
}, [visible]);
// 组件加载时获取模型列表
useEffect(() => {
if (!visible) {
return;
}
if (channelId || Number(channelInfo?.type) === CHANNEL_TYPE_OLLAMA) {
fetchModels();
}
}, [
visible,
channelId,
channelInfo?.type,
channelInfo?.base_url,
channelInfo?.other_info,
channelInfo?.ollama_base_url,
]);
// 组件卸载时清理 EventSource
useEffect(() => {
return () => {
if (eventSource) {
eventSource.close();
}
};
}, [eventSource]);
const formatModelSize = (size) => {
if (!size) return '-';
const gb = size / (1024 * 1024 * 1024);
return gb >= 1
? `${gb.toFixed(1)} GB`
: `${(size / (1024 * 1024)).toFixed(0)} MB`;
};
return (
<Modal
title={t('Ollama 模型管理')}
visible={visible}
onCancel={onCancel}
width={720}
style={{ maxWidth: '95vw' }}
footer={
<Button theme='solid' type='primary' onClick={onCancel}>
{t('关闭')}
</Button>
}
>
<Space vertical spacing='medium' style={{ width: '100%' }}>
<div>
<Text type='tertiary' size='small'>
{channelInfo?.name ? `${channelInfo.name} - ` : ''}
{t('管理 Ollama 模型的拉取和删除')}
</Text>
</div>
{/* 拉取新模型 */}
<Card>
<Title heading={6} className='m-0 mb-3'>
{t('拉取新模型')}
</Title>
<Row gutter={12} align='middle'>
<Col span={16}>
<Input
placeholder={t('请输入模型名称,例如: llama3.2, qwen2.5:7b')}
value={pullModelName}
onChange={(value) => setPullModelName(value)}
onEnterPress={pullModel}
disabled={pullLoading}
showClear
/>
</Col>
<Col span={8}>
<Button
theme='solid'
type='primary'
onClick={pullModel}
loading={pullLoading}
disabled={!pullModelName.trim()}
icon={<IconDownload />}
block
>
{pullLoading ? t('拉取中...') : t('拉取模型')}
</Button>
</Col>
</Row>
{/* 进度条显示 */}
{pullProgress &&
(() => {
const completedBytes = Number(pullProgress.completed) || 0;
const totalBytes = Number(pullProgress.total) || 0;
const hasTotal = Number.isFinite(totalBytes) && totalBytes > 0;
const safePercent = hasTotal
? Math.min(
100,
Math.max(
0,
Math.round((completedBytes / totalBytes) * 100),
),
)
: null;
const percentText =
hasTotal && safePercent !== null
? `${safePercent.toFixed(0)}%`
: pullProgress.status || t('处理中');
return (
<div style={{ marginTop: 12 }}>
<div className='flex items-center justify-between mb-2'>
<Text strong>{t('拉取进度')}</Text>
<Text type='tertiary' size='small'>
{percentText}
</Text>
</div>
{hasTotal && safePercent !== null ? (
<div>
<Progress
percent={safePercent}
showInfo={false}
stroke='#1890ff'
size='small'
/>
<div className='flex justify-between mt-1'>
<Text type='tertiary' size='small'>
{(completedBytes / (1024 * 1024 * 1024)).toFixed(2)}{' '}
GB
</Text>
<Text type='tertiary' size='small'>
{(totalBytes / (1024 * 1024 * 1024)).toFixed(2)} GB
</Text>
</div>
</div>
) : (
<div className='flex items-center gap-2 text-xs text-[var(--semi-color-text-2)]'>
<Spin size='small' />
<span>{t('准备中...')}</span>
</div>
)}
</div>
);
})()}
<Text type='tertiary' size='small' className='mt-2 block'>
{t(
'支持拉取 Ollama 官方模型库中的所有模型,拉取过程可能需要几分钟时间',
)}
</Text>
</Card>
{/* 已有模型列表 */}
<Card>
<div className='flex items-center justify-between mb-3'>
<div className='flex items-center gap-2'>
<Title heading={6} className='m-0'>
{t('已有模型')}
</Title>
{models.length > 0 ? (
<Tag color='blue'>{models.length}</Tag>
) : null}
</div>
<Space wrap>
<Input
prefix={<IconSearch />}
placeholder={t('搜索模型...')}
value={searchValue}
onChange={(value) => setSearchValue(value)}
style={{ width: 200 }}
showClear
/>
<Button
size='small'
theme='light'
onClick={handleSelectAll}
disabled={models.length === 0}
>
{t('全选')}
</Button>
<Button
size='small'
theme='light'
onClick={handleClearSelection}
disabled={selectedModelIds.length === 0}
>
{t('清空')}
</Button>
<Button
theme='solid'
type='primary'
icon={<IconPlus />}
onClick={handleApplyAllModels}
disabled={selectedModelIds.length === 0}
size='small'
>
{t('加入渠道')}
</Button>
<Button
theme='light'
type='primary'
onClick={fetchModels}
loading={loading}
icon={<IconRefresh />}
size='small'
>
{t('刷新')}
</Button>
</Space>
</div>
<Spin spinning={loading}>
{filteredModels.length === 0 ? (
<Empty
title={searchValue ? t('未找到匹配的模型') : t('暂无模型')}
description={
searchValue
? t('请尝试其他搜索关键词')
: t('您可以在上方拉取需要的模型')
}
style={{ padding: '40px 0' }}
/>
) : (
<List
dataSource={filteredModels}
split
renderItem={(model) => (
<List.Item key={model.id}>
<div className='flex items-center justify-between w-full'>
<div className='flex items-center flex-1 min-w-0 gap-3'>
<Checkbox
checked={selectedModelIds.includes(model.id)}
onChange={(checked) =>
handleToggleModel(model.id, checked)
}
/>
<div className='flex-1 min-w-0'>
<Text strong className='block truncate'>
{model.id}
</Text>
<div className='flex items-center space-x-2 mt-1'>
<Tag color='cyan' size='small'>
{model.owned_by || 'ollama'}
</Tag>
{model.size && (
<Text type='tertiary' size='small'>
{formatModelSize(model.size)}
</Text>
)}
</div>
</div>
</div>
<div className='flex items-center space-x-2 ml-4'>
<Popconfirm
title={t('确认删除模型')}
content={t(
'删除后无法恢复,确定要删除模型 "{{name}}" 吗?',
{ name: model.id },
)}
onConfirm={() => deleteModel(model.id)}
okText={t('确认')}
cancelText={t('取消')}
>
<Button
theme='borderless'
type='danger'
size='small'
icon={<IconDelete />}
/>
</Popconfirm>
</div>
</div>
</List.Item>
)}
/>
)}
</Spin>
</Card>
</Space>
</Modal>
);
};
export default OllamaModelModal;

View File

@@ -0,0 +1,109 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Button, Popconfirm } from '@douyinfe/semi-ui';
import CompactModeToggle from '../../common/ui/CompactModeToggle';
const DeploymentsActions = ({
selectedKeys,
setSelectedKeys,
setEditingDeployment,
setShowEdit,
batchDeleteDeployments,
batchOperationsEnabled = true,
compactMode,
setCompactMode,
showCreateModal,
setShowCreateModal,
t,
}) => {
const hasSelected = batchOperationsEnabled && selectedKeys.length > 0;
const handleAddDeployment = () => {
if (setShowCreateModal) {
setShowCreateModal(true);
} else {
// Fallback to old behavior if setShowCreateModal is not provided
setEditingDeployment({ id: undefined });
setShowEdit(true);
}
};
const handleBatchDelete = () => {
batchDeleteDeployments();
};
const handleDeselectAll = () => {
setSelectedKeys([]);
};
return (
<div className='flex flex-wrap gap-2 w-full md:w-auto order-2 md:order-1'>
<Button
type='primary'
className='flex-1 md:flex-initial'
onClick={handleAddDeployment}
size='small'
>
{t('新建容器')}
</Button>
{hasSelected && (
<>
<Popconfirm
title={t('确认删除')}
content={`${t('确定要删除选中的')} ${selectedKeys.length} ${t('个部署吗?此操作不可逆。')}`}
okText={t('删除')}
cancelText={t('取消')}
okType='danger'
onConfirm={handleBatchDelete}
>
<Button
type='danger'
className='flex-1 md:flex-initial'
disabled={selectedKeys.length === 0}
size='small'
>
{t('批量删除')} ({selectedKeys.length})
</Button>
</Popconfirm>
<Button
type='tertiary'
className='flex-1 md:flex-initial'
onClick={handleDeselectAll}
size='small'
>
{t('取消选择')}
</Button>
</>
)}
{/* Compact Mode */}
<CompactModeToggle
compactMode={compactMode}
setCompactMode={setCompactMode}
t={t}
/>
</div>
);
};
export default DeploymentsActions;

View File

@@ -0,0 +1,702 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React from 'react';
import { Button, Dropdown, Tag, Typography } from '@douyinfe/semi-ui';
import { timestamp2string, showSuccess, showError } from '../../../helpers';
import { IconMore } from '@douyinfe/semi-icons';
import {
FaPlay,
FaTrash,
FaServer,
FaMemory,
FaMicrochip,
FaCheckCircle,
FaSpinner,
FaClock,
FaExclamationCircle,
FaBan,
FaTerminal,
FaPlus,
FaCog,
FaInfoCircle,
FaLink,
FaStop,
FaHourglassHalf,
FaGlobe,
} from 'react-icons/fa';
const normalizeStatus = (status) =>
typeof status === 'string' ? status.trim().toLowerCase() : '';
const STATUS_TAG_CONFIG = {
running: {
color: 'green',
labelKey: '运行中',
icon: <FaPlay size={12} className='text-green-600' />,
},
deploying: {
color: 'blue',
labelKey: '部署中',
icon: <FaSpinner size={12} className='text-blue-600' />,
},
pending: {
color: 'orange',
labelKey: '待部署',
icon: <FaClock size={12} className='text-orange-600' />,
},
stopped: {
color: 'grey',
labelKey: '已停止',
icon: <FaStop size={12} className='text-gray-500' />,
},
error: {
color: 'red',
labelKey: '错误',
icon: <FaExclamationCircle size={12} className='text-red-500' />,
},
failed: {
color: 'red',
labelKey: '失败',
icon: <FaExclamationCircle size={12} className='text-red-500' />,
},
destroyed: {
color: 'red',
labelKey: '已销毁',
icon: <FaBan size={12} className='text-red-500' />,
},
completed: {
color: 'green',
labelKey: '已完成',
icon: <FaCheckCircle size={12} className='text-green-600' />,
},
'deployment requested': {
color: 'blue',
labelKey: '部署请求中',
icon: <FaSpinner size={12} className='text-blue-600' />,
},
'termination requested': {
color: 'orange',
labelKey: '终止请求中',
icon: <FaClock size={12} className='text-orange-600' />,
},
};
const DEFAULT_STATUS_CONFIG = {
color: 'grey',
labelKey: null,
icon: <FaInfoCircle size={12} className='text-gray-500' />,
};
const parsePercentValue = (value) => {
if (value === null || value === undefined) return null;
if (typeof value === 'string') {
const parsed = parseFloat(value.replace(/[^0-9.+-]/g, ''));
return Number.isFinite(parsed) ? parsed : null;
}
if (typeof value === 'number') {
return Number.isFinite(value) ? value : null;
}
return null;
};
const clampPercent = (value) => {
if (value === null || value === undefined) return null;
return Math.min(100, Math.max(0, Math.round(value)));
};
const formatRemainingMinutes = (minutes, t) => {
if (minutes === null || minutes === undefined) return null;
const numeric = Number(minutes);
if (!Number.isFinite(numeric)) return null;
const totalMinutes = Math.max(0, Math.round(numeric));
const days = Math.floor(totalMinutes / 1440);
const hours = Math.floor((totalMinutes % 1440) / 60);
const mins = totalMinutes % 60;
const parts = [];
if (days > 0) {
parts.push(`${days}${t('天')}`);
}
if (hours > 0) {
parts.push(`${hours}${t('小时')}`);
}
if (parts.length === 0 || mins > 0) {
parts.push(`${mins}${t('分钟')}`);
}
return parts.join(' ');
};
const getRemainingTheme = (percentRemaining) => {
if (percentRemaining === null) {
return {
iconColor: 'var(--semi-color-primary)',
tagColor: 'blue',
textColor: 'var(--semi-color-text-2)',
};
}
if (percentRemaining <= 10) {
return {
iconColor: '#ff5a5f',
tagColor: 'red',
textColor: '#ff5a5f',
};
}
if (percentRemaining <= 30) {
return {
iconColor: '#ffb400',
tagColor: 'orange',
textColor: '#ffb400',
};
}
return {
iconColor: '#2ecc71',
tagColor: 'green',
textColor: '#2ecc71',
};
};
const renderStatus = (status, t) => {
const normalizedStatus = normalizeStatus(status);
const config = STATUS_TAG_CONFIG[normalizedStatus] || DEFAULT_STATUS_CONFIG;
const statusText = typeof status === 'string' ? status : '';
const labelText = config.labelKey
? t(config.labelKey)
: statusText || t('未知状态');
return (
<Tag
color={config.color}
shape='circle'
size='small'
prefixIcon={config.icon}
>
{labelText}
</Tag>
);
};
// Container Name Cell Component - to properly handle React hooks
const ContainerNameCell = ({ text, record, t }) => {
const handleCopyId = async () => {
try {
await navigator.clipboard.writeText(record.id);
showSuccess(t('已复制 ID 到剪贴板'));
} catch (err) {
showError(t('复制失败'));
}
};
return (
<div className='flex flex-col gap-1'>
<Typography.Text strong className='text-base'>
{text}
</Typography.Text>
<Typography.Text
type='secondary'
size='small'
className='text-xs cursor-pointer hover:text-blue-600 transition-colors select-all'
onClick={handleCopyId}
title={t('点击复制ID')}
>
ID: {record.id}
</Typography.Text>
</div>
);
};
// Render resource configuration
const renderResourceConfig = (resource, t) => {
if (!resource) return '-';
const { cpu, memory, gpu } = resource;
return (
<div className='flex flex-col gap-1'>
{cpu && (
<div className='flex items-center gap-1 text-xs'>
<FaMicrochip className='text-blue-500' />
<span>CPU: {cpu}</span>
</div>
)}
{memory && (
<div className='flex items-center gap-1 text-xs'>
<FaMemory className='text-green-500' />
<span>内存: {memory}</span>
</div>
)}
{gpu && (
<div className='flex items-center gap-1 text-xs'>
<FaServer className='text-purple-500' />
<span>GPU: {gpu}</span>
</div>
)}
</div>
);
};
// Render instance count with status indicator
const renderInstanceCount = (count, record, t) => {
const normalizedStatus = normalizeStatus(record?.status);
const statusConfig = STATUS_TAG_CONFIG[normalizedStatus];
const countColor = statusConfig?.color ?? 'grey';
return (
<Tag color={countColor} size='small' shape='circle'>
{count || 0} {t('个实例')}
</Tag>
);
};
// Main function to get all deployment columns
export const getDeploymentsColumns = ({
t,
COLUMN_KEYS,
startDeployment,
restartDeployment,
deleteDeployment,
setEditingDeployment,
setShowEdit,
refresh,
activePage,
deployments,
// New handlers for enhanced operations
onViewLogs,
onExtendDuration,
onViewDetails,
onUpdateConfig,
onSyncToChannel,
}) => {
const columns = [
{
title: t('容器名称'),
dataIndex: 'container_name',
key: COLUMN_KEYS.container_name,
width: 300,
ellipsis: true,
render: (text, record) => (
<ContainerNameCell text={text} record={record} t={t} />
),
},
{
title: t('状态'),
dataIndex: 'status',
key: COLUMN_KEYS.status,
width: 140,
render: (status) => (
<div className='flex items-center gap-2'>{renderStatus(status, t)}</div>
),
},
{
title: t('服务商'),
dataIndex: 'provider',
key: COLUMN_KEYS.provider,
width: 140,
render: (provider) =>
provider ? (
<div
className='flex items-center gap-1.5 rounded-full border px-2 py-0.5 text-[10px] font-medium uppercase tracking-wide'
style={{
borderColor: 'rgba(59, 130, 246, 0.4)',
backgroundColor: 'rgba(59, 130, 246, 0.08)',
color: '#2563eb',
}}
>
<FaGlobe className='text-[11px]' />
<span>{provider}</span>
</div>
) : (
<Typography.Text
type='tertiary'
size='small'
className='text-xs text-gray-500'
>
{t('暂无')}
</Typography.Text>
),
},
{
title: t('剩余时间'),
dataIndex: 'time_remaining',
key: COLUMN_KEYS.time_remaining,
width: 200,
render: (text, record) => {
const normalizedStatus = normalizeStatus(record?.status);
const percentUsedRaw = parsePercentValue(record?.completed_percent);
const percentUsed = clampPercent(percentUsedRaw);
const percentRemaining =
percentUsed === null ? null : clampPercent(100 - percentUsed);
const theme = getRemainingTheme(percentRemaining);
const statusDisplayMap = {
completed: t('已完成'),
destroyed: t('已销毁'),
failed: t('失败'),
error: t('失败'),
stopped: t('已停止'),
pending: t('待部署'),
deploying: t('部署中'),
'deployment requested': t('部署请求中'),
'termination requested': t('终止中'),
};
const statusOverride = statusDisplayMap[normalizedStatus];
const baseTimeDisplay =
text && String(text).trim() !== '' ? text : t('计算中');
const timeDisplay = baseTimeDisplay;
const humanReadable = formatRemainingMinutes(
record.compute_minutes_remaining,
t,
);
const showProgress = !statusOverride && normalizedStatus === 'running';
const showExtraInfo = Boolean(humanReadable || percentUsed !== null);
const showRemainingMeta =
record.compute_minutes_remaining !== undefined &&
record.compute_minutes_remaining !== null &&
percentRemaining !== null;
return (
<div className='flex flex-col gap-1 leading-tight text-xs'>
<div className='flex items-center gap-1.5'>
<FaHourglassHalf
className='text-sm'
style={{ color: theme.iconColor }}
/>
<Typography.Text className='text-sm font-medium text-[var(--semi-color-text-0)]'>
{timeDisplay}
</Typography.Text>
{showProgress && percentRemaining !== null ? (
<Tag size='small' color={theme.tagColor}>
{percentRemaining}%
</Tag>
) : statusOverride ? (
<Tag size='small' color='grey'>
{statusOverride}
</Tag>
) : null}
</div>
{showExtraInfo && (
<div className='flex items-center gap-3 text-[var(--semi-color-text-2)]'>
{humanReadable && (
<span className='flex items-center gap-1'>
<FaClock className='text-[11px]' />
{t('约')} {humanReadable}
</span>
)}
{percentUsed !== null && (
<span className='flex items-center gap-1'>
<FaCheckCircle className='text-[11px]' />
{t('已用')} {percentUsed}%
</span>
)}
</div>
)}
{showProgress && showRemainingMeta && (
<div className='text-[10px]' style={{ color: theme.textColor }}>
{t('剩余')} {record.compute_minutes_remaining} {t('分钟')}
</div>
)}
</div>
);
},
},
{
title: t('硬件配置'),
dataIndex: 'hardware_info',
key: COLUMN_KEYS.hardware_info,
width: 220,
ellipsis: true,
render: (text, record) => (
<div className='flex items-center gap-2'>
<div className='flex items-center gap-1 px-2 py-1 bg-green-50 border border-green-200 rounded-md'>
<FaServer className='text-green-600 text-xs' />
<span className='text-xs font-medium text-green-700'>
{record.hardware_name}
</span>
</div>
<span className='text-xs text-gray-500 font-medium'>
x{record.hardware_quantity}
</span>
</div>
),
},
{
title: t('创建时间'),
dataIndex: 'created_at',
key: COLUMN_KEYS.created_at,
width: 150,
render: (text) => (
<span className='text-sm text-gray-600'>{timestamp2string(text)}</span>
),
},
{
title: t('操作'),
key: COLUMN_KEYS.actions,
fixed: 'right',
width: 120,
render: (_, record) => {
const { status, id } = record;
const normalizedStatus = normalizeStatus(status);
const isEnded =
normalizedStatus === 'completed' || normalizedStatus === 'destroyed';
const handleDelete = () => {
// Use enhanced confirmation dialog
onUpdateConfig?.(record, 'delete');
};
// Get primary action based on status
const getPrimaryAction = () => {
switch (normalizedStatus) {
case 'running':
return {
icon: <FaInfoCircle className='text-xs' />,
text: t('查看详情'),
onClick: () => onViewDetails?.(record),
type: 'secondary',
theme: 'borderless',
};
case 'failed':
case 'error':
return {
icon: <FaPlay className='text-xs' />,
text: t('重试'),
onClick: () => startDeployment(id),
type: 'primary',
theme: 'solid',
};
case 'stopped':
return {
icon: <FaPlay className='text-xs' />,
text: t('启动'),
onClick: () => startDeployment(id),
type: 'primary',
theme: 'solid',
};
case 'deployment requested':
case 'deploying':
return {
icon: <FaClock className='text-xs' />,
text: t('部署中'),
onClick: () => {},
type: 'secondary',
theme: 'light',
disabled: true,
};
case 'pending':
return {
icon: <FaClock className='text-xs' />,
text: t('待部署'),
onClick: () => {},
type: 'secondary',
theme: 'light',
disabled: true,
};
case 'termination requested':
return {
icon: <FaClock className='text-xs' />,
text: t('终止中'),
onClick: () => {},
type: 'secondary',
theme: 'light',
disabled: true,
};
case 'completed':
case 'destroyed':
default:
return {
icon: <FaInfoCircle className='text-xs' />,
text: t('已结束'),
onClick: () => {},
type: 'tertiary',
theme: 'borderless',
disabled: true,
};
}
};
const primaryAction = getPrimaryAction();
const primaryTheme = primaryAction.theme || 'solid';
const primaryType = primaryAction.type || 'primary';
if (isEnded) {
return (
<div className='flex w-full items-center justify-start gap-1 pr-2'>
<Button
size='small'
type='tertiary'
theme='borderless'
onClick={() => onViewDetails?.(record)}
icon={<FaInfoCircle className='text-xs' />}
>
{t('查看详情')}
</Button>
</div>
);
}
// All actions dropdown with enhanced operations
const dropdownItems = [
<Dropdown.Item
key='details'
onClick={() => onViewDetails?.(record)}
icon={<FaInfoCircle />}
>
{t('查看详情')}
</Dropdown.Item>,
];
if (!isEnded) {
dropdownItems.push(
<Dropdown.Item
key='logs'
onClick={() => onViewLogs?.(record)}
icon={<FaTerminal />}
>
{t('查看日志')}
</Dropdown.Item>,
);
}
const managementItems = [];
if (normalizedStatus === 'running') {
if (onSyncToChannel) {
managementItems.push(
<Dropdown.Item
key='sync-channel'
onClick={() => onSyncToChannel(record)}
icon={<FaLink />}
>
{t('同步到渠道')}
</Dropdown.Item>,
);
}
}
if (normalizedStatus === 'failed' || normalizedStatus === 'error') {
managementItems.push(
<Dropdown.Item
key='retry'
onClick={() => startDeployment(id)}
icon={<FaPlay />}
>
{t('重试')}
</Dropdown.Item>,
);
}
if (normalizedStatus === 'stopped') {
managementItems.push(
<Dropdown.Item
key='start'
onClick={() => startDeployment(id)}
icon={<FaPlay />}
>
{t('启动')}
</Dropdown.Item>,
);
}
if (managementItems.length > 0) {
dropdownItems.push(<Dropdown.Divider key='management-divider' />);
dropdownItems.push(...managementItems);
}
const configItems = [];
if (
!isEnded &&
(normalizedStatus === 'running' ||
normalizedStatus === 'deployment requested')
) {
configItems.push(
<Dropdown.Item
key='extend'
onClick={() => onExtendDuration?.(record)}
icon={<FaPlus />}
>
{t('延长时长')}
</Dropdown.Item>,
);
}
// if (!isEnded && normalizedStatus === 'running') {
// configItems.push(
// <Dropdown.Item key="update-config" onClick={() => onUpdateConfig?.(record)} icon={<FaCog />}>
// {t('更新配置')}
// </Dropdown.Item>,
// );
// }
if (configItems.length > 0) {
dropdownItems.push(<Dropdown.Divider key='config-divider' />);
dropdownItems.push(...configItems);
}
if (!isEnded) {
dropdownItems.push(<Dropdown.Divider key='danger-divider' />);
dropdownItems.push(
<Dropdown.Item
key='delete'
type='danger'
onClick={handleDelete}
icon={<FaTrash />}
>
{t('销毁容器')}
</Dropdown.Item>,
);
}
const allActions = <Dropdown.Menu>{dropdownItems}</Dropdown.Menu>;
const hasDropdown = dropdownItems.length > 0;
return (
<div className='flex w-full items-center justify-start gap-1 pr-2'>
<Button
size='small'
theme={primaryTheme}
type={primaryType}
icon={primaryAction.icon}
onClick={primaryAction.onClick}
className='px-2 text-xs'
disabled={primaryAction.disabled}
>
{primaryAction.text}
</Button>
{hasDropdown && (
<Dropdown
trigger='click'
position='bottomRight'
render={allActions}
>
<Button
size='small'
theme='light'
type='tertiary'
icon={<IconMore />}
className='px-1'
/>
</Dropdown>
)}
</div>
);
},
},
];
return columns;
};

View File

@@ -0,0 +1,130 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useRef } from 'react';
import { Form, Button } from '@douyinfe/semi-ui';
import { IconSearch, IconRefresh } from '@douyinfe/semi-icons';
const DeploymentsFilters = ({
formInitValues,
setFormApi,
searchDeployments,
loading,
searching,
setShowColumnSelector,
t,
}) => {
const formApiRef = useRef(null);
const handleSubmit = (values) => {
searchDeployments(values);
};
const handleReset = () => {
if (!formApiRef.current) return;
formApiRef.current.reset();
setTimeout(() => {
formApiRef.current.submitForm();
}, 0);
};
const statusOptions = [
{ label: t('全部状态'), value: '' },
{ label: t('运行中'), value: 'running' },
{ label: t('已完成'), value: 'completed' },
{ label: t('失败'), value: 'failed' },
{ label: t('部署请求中'), value: 'deployment requested' },
{ label: t('终止请求中'), value: 'termination requested' },
{ label: t('已销毁'), value: 'destroyed' },
];
return (
<Form
layout='horizontal'
onSubmit={handleSubmit}
initValues={formInitValues}
getFormApi={(formApi) => {
setFormApi(formApi);
formApiRef.current = formApi;
}}
className='w-full md:w-auto order-1 md:order-2'
>
<div className='flex flex-col md:flex-row items-center gap-2 w-full md:w-auto'>
<div className='w-full md:w-64'>
<Form.Input
field='searchKeyword'
placeholder={t('搜索部署名称')}
prefix={<IconSearch />}
showClear
size='small'
pure
/>
</div>
<div className='w-full md:w-48'>
<Form.Select
field='searchStatus'
placeholder={t('选择状态')}
optionList={statusOptions}
className='w-full'
showClear
size='small'
pure
/>
</div>
<div className='flex gap-2 w-full md:w-auto'>
<Button
htmlType='submit'
type='tertiary'
icon={<IconSearch />}
loading={searching}
disabled={loading}
size='small'
className='flex-1 md:flex-initial md:w-auto'
>
{t('查询')}
</Button>
<Button
type='tertiary'
icon={<IconRefresh />}
onClick={handleReset}
disabled={loading || searching}
size='small'
className='flex-1 md:flex-initial md:w-auto'
>
{t('重置')}
</Button>
<Button
type='tertiary'
onClick={() => setShowColumnSelector(true)}
size='small'
className='flex-1 md:flex-initial md:w-auto'
>
{t('列设置')}
</Button>
</div>
</div>
</Form>
);
};
export default DeploymentsFilters;

View File

@@ -0,0 +1,247 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useMemo, useState } from 'react';
import { Empty } from '@douyinfe/semi-ui';
import CardTable from '../../common/ui/CardTable';
import {
IllustrationNoResult,
IllustrationNoResultDark,
} from '@douyinfe/semi-illustrations';
import { getDeploymentsColumns } from './DeploymentsColumnDefs';
// Import all the new modals
import ViewLogsModal from './modals/ViewLogsModal';
import ExtendDurationModal from './modals/ExtendDurationModal';
import ViewDetailsModal from './modals/ViewDetailsModal';
import UpdateConfigModal from './modals/UpdateConfigModal';
import ConfirmationDialog from './modals/ConfirmationDialog';
const DeploymentsTable = (deploymentsData) => {
const {
deployments,
loading,
searching,
activePage,
pageSize,
deploymentCount,
compactMode,
visibleColumns,
rowSelection,
batchOperationsEnabled = true,
handlePageChange,
handlePageSizeChange,
handleRow,
t,
COLUMN_KEYS,
// Column functions and data
startDeployment,
restartDeployment,
deleteDeployment,
syncDeploymentToChannel,
setEditingDeployment,
setShowEdit,
refresh,
} = deploymentsData;
// Modal states
const [selectedDeployment, setSelectedDeployment] = useState(null);
const [showLogsModal, setShowLogsModal] = useState(false);
const [showExtendModal, setShowExtendModal] = useState(false);
const [showDetailsModal, setShowDetailsModal] = useState(false);
const [showConfigModal, setShowConfigModal] = useState(false);
const [showConfirmDialog, setShowConfirmDialog] = useState(false);
const [confirmOperation, setConfirmOperation] = useState('delete');
// Enhanced modal handlers
const handleViewLogs = (deployment) => {
setSelectedDeployment(deployment);
setShowLogsModal(true);
};
const handleExtendDuration = (deployment) => {
setSelectedDeployment(deployment);
setShowExtendModal(true);
};
const handleViewDetails = (deployment) => {
setSelectedDeployment(deployment);
setShowDetailsModal(true);
};
const handleUpdateConfig = (deployment, operation = 'update') => {
setSelectedDeployment(deployment);
if (operation === 'delete' || operation === 'destroy') {
setConfirmOperation(operation);
setShowConfirmDialog(true);
} else {
setShowConfigModal(true);
}
};
const handleConfirmAction = () => {
if (
selectedDeployment &&
(confirmOperation === 'delete' || confirmOperation === 'destroy')
) {
deleteDeployment(selectedDeployment.id);
}
setShowConfirmDialog(false);
setSelectedDeployment(null);
};
const handleModalSuccess = (updatedDeployment) => {
// Refresh the deployments list
refresh?.();
};
// Get all columns
const allColumns = useMemo(() => {
return getDeploymentsColumns({
t,
COLUMN_KEYS,
startDeployment,
restartDeployment,
deleteDeployment,
setEditingDeployment,
setShowEdit,
refresh,
activePage,
deployments,
// Enhanced handlers
onViewLogs: handleViewLogs,
onExtendDuration: handleExtendDuration,
onViewDetails: handleViewDetails,
onUpdateConfig: handleUpdateConfig,
onSyncToChannel: syncDeploymentToChannel,
});
}, [
t,
COLUMN_KEYS,
startDeployment,
restartDeployment,
deleteDeployment,
syncDeploymentToChannel,
setEditingDeployment,
setShowEdit,
refresh,
activePage,
deployments,
]);
// Filter columns based on visibility settings
const getVisibleColumns = () => {
return allColumns.filter((column) => visibleColumns[column.key]);
};
const visibleColumnsList = useMemo(() => {
return getVisibleColumns();
}, [visibleColumns, allColumns]);
const tableColumns = useMemo(() => {
if (compactMode) {
// In compact mode, remove fixed columns and adjust widths
return visibleColumnsList.map(({ fixed, width, ...rest }) => ({
...rest,
width: width ? Math.max(width * 0.8, 80) : undefined, // Reduce width by 20% but keep minimum
}));
}
return visibleColumnsList;
}, [compactMode, visibleColumnsList]);
return (
<>
<CardTable
columns={tableColumns}
dataSource={deployments}
scroll={compactMode ? { x: 800 } : { x: 1200 }}
pagination={{
currentPage: activePage,
pageSize: pageSize,
total: deploymentCount,
pageSizeOpts: [10, 20, 50, 100],
showSizeChanger: true,
onPageSizeChange: handlePageSizeChange,
onPageChange: handlePageChange,
}}
hidePagination={true}
expandAllRows={false}
onRow={handleRow}
rowSelection={batchOperationsEnabled ? rowSelection : undefined}
empty={
<Empty
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
darkModeImage={
<IllustrationNoResultDark style={{ width: 150, height: 150 }} />
}
description={t('搜索无结果')}
style={{ padding: 30 }}
/>
}
className='rounded-xl overflow-hidden'
size='middle'
loading={loading || searching}
/>
{/* Enhanced Modals */}
<ViewLogsModal
visible={showLogsModal}
onCancel={() => setShowLogsModal(false)}
deployment={selectedDeployment}
t={t}
/>
<ExtendDurationModal
visible={showExtendModal}
onCancel={() => setShowExtendModal(false)}
deployment={selectedDeployment}
onSuccess={handleModalSuccess}
t={t}
/>
<ViewDetailsModal
visible={showDetailsModal}
onCancel={() => setShowDetailsModal(false)}
deployment={selectedDeployment}
t={t}
/>
<UpdateConfigModal
visible={showConfigModal}
onCancel={() => setShowConfigModal(false)}
deployment={selectedDeployment}
onSuccess={handleModalSuccess}
t={t}
/>
<ConfirmationDialog
visible={showConfirmDialog}
onCancel={() => setShowConfirmDialog(false)}
onConfirm={handleConfirmAction}
title={t('确认操作')}
type='danger'
deployment={selectedDeployment}
operation={confirmOperation}
t={t}
/>
</>
);
};
export default DeploymentsTable;

View File

@@ -0,0 +1,152 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState } from 'react';
import CardPro from '../../common/ui/CardPro';
import DeploymentsTable from './DeploymentsTable';
import DeploymentsActions from './DeploymentsActions';
import DeploymentsFilters from './DeploymentsFilters';
import EditDeploymentModal from './modals/EditDeploymentModal';
import CreateDeploymentModal from './modals/CreateDeploymentModal';
import ColumnSelectorModal from './modals/ColumnSelectorModal';
import { useDeploymentsData } from '../../../hooks/model-deployments/useDeploymentsData';
import { useIsMobile } from '../../../hooks/common/useIsMobile';
import { createCardProPagination } from '../../../helpers/utils';
const DeploymentsPage = () => {
const deploymentsData = useDeploymentsData();
const isMobile = useIsMobile();
// Create deployment modal state
const [showCreateModal, setShowCreateModal] = useState(false);
const batchOperationsEnabled = false;
const {
// Edit state
showEdit,
editingDeployment,
closeEdit,
refresh,
// Actions state
selectedKeys,
setSelectedKeys,
setEditingDeployment,
setShowEdit,
batchDeleteDeployments,
// Filters state
formInitValues,
setFormApi,
searchDeployments,
loading,
searching,
// Column visibility
showColumnSelector,
setShowColumnSelector,
visibleColumns,
setVisibleColumns,
COLUMN_KEYS,
// Description state
compactMode,
setCompactMode,
// Translation
t,
} = deploymentsData;
return (
<>
{/* Modals */}
<EditDeploymentModal
refresh={refresh}
editingDeployment={editingDeployment}
visible={showEdit}
handleClose={closeEdit}
/>
<CreateDeploymentModal
visible={showCreateModal}
onCancel={() => setShowCreateModal(false)}
onSuccess={refresh}
t={t}
/>
<ColumnSelectorModal
visible={showColumnSelector}
onCancel={() => setShowColumnSelector(false)}
visibleColumns={visibleColumns}
onVisibleColumnsChange={setVisibleColumns}
columnKeys={COLUMN_KEYS}
t={t}
/>
{/* Main Content */}
<CardPro
type='type3'
actionsArea={
<div className='flex flex-col md:flex-row justify-between items-center gap-2 w-full'>
<DeploymentsActions
selectedKeys={selectedKeys}
setSelectedKeys={setSelectedKeys}
setEditingDeployment={setEditingDeployment}
setShowEdit={setShowEdit}
batchDeleteDeployments={batchDeleteDeployments}
batchOperationsEnabled={batchOperationsEnabled}
compactMode={compactMode}
setCompactMode={setCompactMode}
showCreateModal={showCreateModal}
setShowCreateModal={setShowCreateModal}
setShowColumnSelector={setShowColumnSelector}
t={t}
/>
<DeploymentsFilters
formInitValues={formInitValues}
setFormApi={setFormApi}
searchDeployments={searchDeployments}
loading={loading}
searching={searching}
setShowColumnSelector={setShowColumnSelector}
t={t}
/>
</div>
}
paginationArea={createCardProPagination({
currentPage: deploymentsData.activePage,
pageSize: deploymentsData.pageSize,
total: deploymentsData.deploymentCount,
onPageChange: deploymentsData.handlePageChange,
onPageSizeChange: deploymentsData.handlePageSizeChange,
isMobile: isMobile,
t: deploymentsData.t,
})}
t={deploymentsData.t}
>
<DeploymentsTable
{...deploymentsData}
batchOperationsEnabled={batchOperationsEnabled}
/>
</CardPro>
</>
);
};
export default DeploymentsPage;

View File

@@ -0,0 +1,127 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useMemo } from 'react';
import { Modal, Button, Checkbox } from '@douyinfe/semi-ui';
const ColumnSelectorModal = ({
visible,
onCancel,
visibleColumns,
onVisibleColumnsChange,
columnKeys,
t,
}) => {
const columnOptions = useMemo(
() => [
{ key: columnKeys.container_name, label: t('容器名称'), required: true },
{ key: columnKeys.status, label: t('状态') },
{ key: columnKeys.time_remaining, label: t('剩余时间') },
{ key: columnKeys.hardware_info, label: t('硬件配置') },
{ key: columnKeys.created_at, label: t('创建时间') },
{ key: columnKeys.actions, label: t('操作'), required: true },
],
[columnKeys, t],
);
const handleColumnVisibilityChange = (key, checked) => {
const column = columnOptions.find((option) => option.key === key);
if (column?.required) return;
onVisibleColumnsChange({
...visibleColumns,
[key]: checked,
});
};
const handleSelectAll = (checked) => {
const updated = { ...visibleColumns };
columnOptions.forEach(({ key, required }) => {
updated[key] = required ? true : checked;
});
onVisibleColumnsChange(updated);
};
const handleReset = () => {
const defaults = columnOptions.reduce((acc, { key }) => {
acc[key] = true;
return acc;
}, {});
onVisibleColumnsChange({
...visibleColumns,
...defaults,
});
};
const allSelected = columnOptions.every(
({ key, required }) => required || visibleColumns[key],
);
const indeterminate =
columnOptions.some(
({ key, required }) => !required && visibleColumns[key],
) && !allSelected;
const handleConfirm = () => onCancel();
return (
<Modal
title={t('列设置')}
visible={visible}
onCancel={onCancel}
footer={
<div className='flex justify-end gap-2'>
<Button onClick={handleReset}>{t('重置')}</Button>
<Button onClick={onCancel}>{t('取消')}</Button>
<Button type='primary' onClick={handleConfirm}>
{t('确定')}
</Button>
</div>
}
>
<div style={{ marginBottom: 20 }}>
<Checkbox
checked={allSelected}
indeterminate={indeterminate}
onChange={(e) => handleSelectAll(e.target.checked)}
>
{t('全选')}
</Checkbox>
</div>
<div
className='flex flex-wrap max-h-96 overflow-y-auto rounded-lg p-4'
style={{ border: '1px solid var(--semi-color-border)' }}
>
{columnOptions.map(({ key, label, required }) => (
<div key={key} className='w-1/2 mb-4 pr-2'>
<Checkbox
checked={!!visibleColumns[key]}
disabled={required}
onChange={(e) =>
handleColumnVisibilityChange(key, e.target.checked)
}
>
{label}
</Checkbox>
</div>
))}
</div>
</Modal>
);
};
export default ColumnSelectorModal;

Some files were not shown because too many files have changed in this diff Show More