Compare commits

...

90 Commits

Author SHA1 Message Date
Calcium-Ion
9c079d04a8 Merge pull request #1487 from seefs001/feature/2fa
feat: implement two-factor authentication (2FA) support with user login and settings integration
2025-08-04 19:54:31 +08:00
Calcium-Ion
c9d4cdc57e Merge pull request #1498 from QuantumNous/multi-key-manage
feat: add multi-key management
2025-08-04 19:53:52 +08:00
CaIon
12b4e80d4b feat: add status filtering and bulk enable/disable functionality in multi-key management 2025-08-04 19:51:58 +08:00
CaIon
6e2a04f374 fix: correct option value for pagination in MultiKeyManageModal 2025-08-04 19:33:24 +08:00
CaIon
8357b15fec feat: enhance multi-key management with pagination and statistics 2025-08-04 17:15:32 +08:00
CaIon
ecdd9d1ccb feat: add multi-key management 2025-08-04 16:52:31 +08:00
Xyfacai
10b04416c1 fix: 修复gemini2openai 没有返回 usage 2025-08-04 09:06:57 +08:00
Seefs
398ae7156b refactor: improve error handling and database transactions in 2FA model methods 2025-08-03 10:49:55 +08:00
Seefs
d85eeabf11 fix: coderabbit review 2025-08-03 10:41:00 +08:00
CaIon
c056a7ad7c feat: add support for multi-key channels in RelayInfo and access token caching 2025-08-02 22:12:15 +08:00
Seefs
c784a70277 feat: implement two-factor authentication (2FA) support with user login and settings integration 2025-08-02 14:53:28 +08:00
Calcium-Ion
e6c87907d5 Merge pull request #1486 from nekohy/fix-get-google-models
fix: correct Gemini channel model retrieval logic
2025-08-02 14:52:22 +08:00
Nekohy
71e9290142 fix: correct Gemini channel model retrieval logic 2025-08-02 14:19:32 +08:00
CaIon
74ec34da67 fix: improve error handling and readability in ability.go 2025-08-02 14:06:12 +08:00
CaIon
7188749cb3 feat: truncate abilities table before processing channels 2025-08-02 13:39:53 +08:00
CaIon
c28add55db feat: add caching for keys in channel structure and retain polling index during sync 2025-08-02 13:16:30 +08:00
CaIon
78f34a8245 feat: retain polling index for multi-key channels during sync 2025-08-02 13:04:48 +08:00
CaIon
97d6f10f15 feat: enhance ConvertGeminiRequest to set default role and handle YouTube video MIME type 2025-08-02 12:53:58 +08:00
Calcium-Ion
afefc4caca Merge pull request #1484 from QuantumNous/ConvertGeminiRequest
feat: Convert gemini request
2025-08-02 12:20:39 +08:00
CaIon
6abbd036f8 feat: add recordErrorLog option to NewAPIError for conditional error logging 2025-08-02 11:07:50 +08:00
CaIon
ef0db0f914 feat: implement key mode for multi-key channels with append/replace options 2025-08-02 10:57:03 +08:00
creamlike1024
e01986fdd4 Merge remote-tracking branch 'origin/alpha' into ConvertGeminiRequest 2025-08-01 22:42:48 +08:00
creamlike1024
a0c6ebe2d8 chore: remove debug log 2025-08-01 22:29:19 +08:00
creamlike1024
d2183af23f feat: convert gemini format to openai chat completions 2025-08-01 22:23:35 +08:00
CaIon
953f1bdc3c feat: add admin info to error logging with multi-key support 2025-08-01 18:19:28 +08:00
CaIon
e2429f20f8 fix: ensure ChannelIsMultiKey context key is set to false for single key retries 2025-08-01 18:09:20 +08:00
CaIon
f0945da4fb refactor: simplify streamResponseGeminiChat2OpenAI by removing hasImage return value and optimizing response text handling 2025-08-01 17:58:21 +08:00
CaIon
8df3de9ae5 fix: update JSONEditor to default to manual mode for invalid JSON and add error message for invalid data 2025-08-01 17:21:25 +08:00
Calcium-Ion
277cc1cac8 Merge pull request #1481 from seefs001/revert-1445-feature/claude-code
Revert "feat: add Claude Code channel support with OAuth integration"
2025-08-01 17:05:22 +08:00
CaIon
07a92293e4 fix: handle case where no response is received from Gemini API 2025-08-01 17:04:16 +08:00
Seefs
f995e31d04 Revert "feat: add Claude Code channel support with OAuth integration" 2025-07-31 22:08:16 +08:00
Calcium-Ion
9758a9e60d Merge pull request #1445 from seefs001/feature/claude-code
feat: add Claude Code channel support with OAuth integration
2025-07-31 21:28:23 +08:00
Seefs
6f56696af2 fix: handle authorization code format in ExchangeCode function and update placeholder in EditChannelModal 2025-07-31 21:27:24 +08:00
Seefs
345fbdf3d2 Merge branch 'alpha' into feature/claude-code
# Conflicts:
#	web/src/components/table/channels/modals/EditChannelModal.jsx
2025-07-31 21:19:43 +08:00
CaIon
ce031f7d15 refactor: update error handling to support dynamic error types 2025-07-31 21:16:01 +08:00
CaIon
bd6b811183 feat: add JSONEditor component for enhanced JSON input handling 2025-07-31 12:54:07 +08:00
CaIon
196bafff03 fix: 修复被禁用的渠道无法测试的问题 2025-07-31 10:56:51 +08:00
CaIon
f20b558e22 fix: correct request mode assignment logic in adaptor 2025-07-30 23:32:20 +08:00
CaIon
54447bf227 fix: remove debug print statement 2025-07-30 23:29:45 +08:00
CaIon
fc09051d8b fix: 修复缓存开启下自动禁用失效 2025-07-30 23:26:09 +08:00
Xyfacai
1f5ef24ecd feat: 显式指定 error 跳过重试 2025-07-30 22:35:31 +08:00
creamlike1024
b1faf42529 Merge branch 'RedwindA-fix/gemini-native-sse' into alpha 2025-07-30 20:37:10 +08:00
creamlike1024
6a85206e32 Merge branch 'fix/gemini-native-sse' of github.com:RedwindA/new-api into RedwindA-fix/gemini-native-sse 2025-07-30 20:34:12 +08:00
CaIon
e3d3e697d3 fix: WriteContentType panic 2025-07-30 20:31:51 +08:00
IcedTangerine
db9b333930 Merge pull request #1405 from RedwindA/fix/gemini-nothinking-handler
fix: improve gemini nothinking handler
2025-07-30 20:31:26 +08:00
CaIon
f7b284ad73 feat: 错误内容脱敏 2025-07-30 19:08:35 +08:00
CaIon
e1970e8a66 Merge remote-tracking branch 'origin/alpha' into alpha 2025-07-30 18:39:32 +08:00
CaIon
0cd93d67ff fix: auto ban 2025-07-30 18:39:19 +08:00
IcedTangerine
6e806e21bd Merge pull request #1472 from QuantumNous/revert-1385-patch-1
Revert 1385 patch 1
2025-07-30 12:19:57 +08:00
IcedTangerine
a8462c1b70 Revert "Update relay-claude.go" 2025-07-30 12:17:56 +08:00
IcedTangerine
706ea8b649 Merge pull request #1385 from QingyeSC/patch-1
Update claude topP argument
2025-07-30 00:12:57 +08:00
CaIon
95d46d1dfc fix: auto ban 2025-07-29 23:08:16 +08:00
CaIon
010f27678d fix: auto ban 2025-07-29 15:20:08 +08:00
t0ng7u
d87117a2cf Merge remote-tracking branch 'origin/alpha' into alpha 2025-07-28 01:33:36 +08:00
t0ng7u
4ed92a94a1 feat: Enhance Channel Model Management UI
Summary
• Introduced standalone `ModelSelectModal.jsx` for selecting channel models
• Fetch-list now opens modal instead of in-place select, keeping EditChannelModal lean

Modal Features
1. Search bar with `IconSearch`, keyboard clear & mobile full-screen support
2. Tab layout (“New Models” / “Existing Models”) displayed next to title, responsive wrapping
3. Models grouped by vendor via `getModelCategories` and rendered inside always-expanded `Collapse` panels
4. Per-category checkbox in panel extra area for bulk select / deselect
5. Footer checkbox for bulk select of all models in current tab, with real-time counter
6. Empty state uses `IllustrationNoResult` / `IllustrationNoResultDark` for visual consistency
7. Accessible header/footer paddings aligned with Semi UI defaults

Fixes & Improvements
• All indeterminate and full-select states handled correctly
• Consistent “selected X / Y” stats synced with active tab, not global list
• All panels now controlled via `activeKey`, ensuring they remain expanded
• Search, vendor grouping, and responsive layout tested across mobile & desktop

These changes modernise the channel model management workflow and prepare the codebase for upcoming upstream-ratio integration.
2025-07-28 01:33:23 +08:00
CaIon
821ea34a3c fix: increase maximum request count limits for user rate settings 2025-07-27 16:46:34 +08:00
Calcium-Ion
ecb3d01376 Merge pull request #1446 from Raymondxox/fix
模型请求速率限制,增加对请求次数最大值的限制
2025-07-27 16:33:49 +08:00
CaIon
e322ed4f05 fix: ensure minimum quota display and handle zero values in render function 2025-07-27 16:32:14 +08:00
Raymond
bcf7e78665 模型请求速率限制,增加对请求次数最大值的限制 2025-07-27 16:01:59 +08:00
t0ng7u
0cb2bb2ea7 🗂️ refactor(table): isolate column preferences per role
Summary
• Added role-specific localStorage keys for column visibility in three hooks:
  - `useUsageLogsData.js` → `logs-table-columns-admin` / `logs-table-columns-user`
  - `useMjLogsData.js`   → `mj-logs-table-columns-admin` / `mj-logs-table-columns-user`
  - `useTaskLogsData.js` → `task-logs-table-columns-admin` / `task-logs-table-columns-user`

Details
1. Each hook now derives a `STORAGE_KEY` based on `isAdminUser`, preventing admin and non-admin sessions from overwriting one another’s column settings.
2. Removed the previous “save but strip admin columns” workaround—settings are persisted unmodified to each role’s key.
3. Kept runtime behaviour: non-admin users still see admin-only columns forcibly hidden.
4. Replaced newly added Chinese comments with clear English equivalents for consistency.

Result
Switching between admin and non-admin accounts no longer corrupts column visibility preferences, and codebase comments are fully English-localized.
2025-07-27 09:49:57 +08:00
t0ng7u
c5d97597c4 🔍 fix: select search filter
Summary
• Introduced a unified `selectFilter` helper that matches both `option.value` and `option.label`, ensuring all `<Select>` components support intuitive search (fixes channel “type” dropdown not filtering).
• Replaced all usages of the old `modelSelectFilter` with `selectFilter` in:
  • `EditChannelModal.jsx`
  • `SettingsPanel.js`
  • `EditTokenModal.jsx`
  • `EditTagModal.jsx`
• Removed the deprecated `modelSelectFilter` export from `utils.js` (no backward-compat alias).
• Updated documentation comments accordingly.

Why
The old filter only inspected `option.value`, causing searches to fail when `label` carried the meaningful text (e.g., numeric IDs for channel types). The new helper searches both fields, covering all scenarios and unifying the API across the codebase.

Notes
No functional regressions expected; all components have been migrated.
2025-07-27 00:01:12 +08:00
Seefs
fe9acb6c59 chore: claude code automatic disable 2025-07-26 18:40:18 +08:00
Seefs
bca78beb1b feat: add claude code channel 2025-07-26 18:06:46 +08:00
t0ng7u
a8a42cbfa8 💄 style(ui): show "Force Format" toggle only for OpenAI channels
Previously, the "Force Format" switch was displayed for every channel type
although it only applies to OpenAI (type === 1).
This change wraps the switch in a conditional so it renders exclusively when
the selected channel type is OpenAI.

Why:
- Prevents user confusion when configuring non-OpenAI channels
- Keeps the UI clean and context-relevant

Scope:
- web/src/components/table/channels/modals/EditChannelModal.jsx

No backend logic affected.
2025-07-26 17:18:47 +08:00
Raymond
19df2ac234 模型请求速率限制,增加对请求次数最大值的限制 2025-07-26 17:09:38 +08:00
Calcium-Ion
e7524c85c2 Merge pull request #1443 from QuantumNous/claude_to_gemini
feat: Claude to gemini (适配claude格式调用gemini渠道模型)
2025-07-26 14:04:47 +08:00
Calcium-Ion
a4356727e9 Merge pull request #1437 from Raymondxox/fix
判断兑换码名称长度,改为按字符长度计算
2025-07-26 14:04:02 +08:00
t0ng7u
f15a53fae4 🎨 refactor(ui): redesign channel extra settings section in EditChannelModal
- Extract channel extra settings into a dedicated Card component for better visual hierarchy
- Replace custom gray background container with consistent Form component styling
- Simplify layout structure by removing complex Row/Col grid layout in favor of native Form component layout
- Unify help text styling by using extraText prop consistently across all form fields
- Move "Settings Documentation" link to card header subtitle for better accessibility
- Improve visual consistency with other setting cards by using matching design patterns

The channel extra settings (force format, thinking content conversion, pass-through body, proxy address, and system prompt) now follow the same design language as other configuration sections, providing a more cohesive user experience.

Affected settings:
- Force Format (OpenAI channels only)
- Thinking Content Conversion
- Pass-through Body
- Proxy Address
- System Prompt
2025-07-26 13:33:10 +08:00
CaIon
8e3cf2eaab feat: support claude convert to gemini 2025-07-26 13:31:33 +08:00
Calcium-Ion
c51ec3135b Merge pull request #1441 from QuantumNous/system_prompt
feat: 支持渠道级透传选项,支持设置渠道系统提示词
2025-07-26 12:16:16 +08:00
CaIon
2469c439b1 fix: improve error messaging and JSON schema handling in distributor and relay components 2025-07-26 12:11:20 +08:00
CaIon
1297addfb1 feat: enhance request handling with pass-through options and system prompt support 2025-07-26 11:39:09 +08:00
CaIon
d6cbf43373 Merge remote-tracking branch 'origin/alpha' into alpha 2025-07-26 10:43:42 +08:00
Raymond
df647e7b42 判断兑换码名称长度,改为按字符长度计算 2025-07-25 22:40:12 +08:00
t0ng7u
fe16d05fbb 🔒 fix: Enforce admin-only column visibility in logs tables
Ensure non-admin users cannot enable columns reserved for administrators
across the following hooks:

* web/src/hooks/usage-logs/useUsageLogsData.js
  - Force-hide CHANNEL, USERNAME and RETRY columns for non-admins.

* web/src/hooks/mj-logs/useMjLogsData.js
  - Force-hide CHANNEL and SUBMIT_RESULT columns for non-admins.

* web/src/hooks/task-logs/useTaskLogsData.js
  - Force-hide CHANNEL column for non-admins.

The checks run when loading column preferences from localStorage, overriding
any tampered settings to keep sensitive information hidden from
unauthorized users.
2025-07-25 20:31:20 +08:00
同語
1430c05b6c 🌟 fix: standardize and improve Playground Chat VIP group functionality
Merge pull request #1424 from feitianbubu/pr/fix-playground-chat-vip-group
2025-07-25 20:18:29 +08:00
CaIon
b25841e50d feat: add upstream error type and default handling for OpenAI and Claude errors 2025-07-25 18:48:59 +08:00
IcedTangerine
b704fc9254 Merge pull request #1425 from feitianbubu/pr/add-vidu-video-channel
feat: add vidu video channel
2025-07-24 12:14:04 +08:00
feitianbubu
352da66bd1 feat: add vidu video channel 2025-07-24 10:14:25 +08:00
feitianbubu
8205ad2cd0 fix: playground chat vip group 2025-07-24 09:38:00 +08:00
CaIon
e162b9c169 feat: support multi-key mode 2025-07-23 22:00:30 +08:00
CaIon
77e3502028 fix(adaptor): enhance response handling and error logging for Claude format 2025-07-23 20:59:56 +08:00
CaIon
ae0461692c feat: support ollama claude format 2025-07-23 20:01:03 +08:00
CaIon
13bdb80958 fix(adaptor): update relay mode handling #1419 2025-07-23 19:28:58 +08:00
CaIon
6f74e7b738 fix(adaptor): implement request conversion methods for Claude and Image. (close #1419) 2025-07-23 19:09:20 +08:00
CaIon
eaee89f77a fix(distributor): add validation for model name in channel selection 2025-07-23 16:46:06 +08:00
RedwindA
6103888610 fix: 修正nothinking判断逻辑,确保仅当预算为零时返回true 2025-07-20 17:35:34 +08:00
RedwindA
7af3fb5ae4 禁用原生Gemini模式中的ping保活 2025-07-18 23:39:01 +08:00
RedwindA
3ac54b2178 增加 DisablePing 字段以控制是否发送自定义 Ping 2025-07-18 23:38:35 +08:00
Glaxy
5621755655 Update relay-claude.go
优化claude messages接口启用思考时的参数设置
2025-07-16 18:00:33 +08:00
118 changed files with 6447 additions and 738 deletions

View File

@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"strings"
"sync"
)
type stringWriter interface {
@@ -52,6 +53,8 @@ type CustomEvent struct {
Id string
Retry uint
Data interface{}
Mutex sync.Mutex
}
func encode(writer io.Writer, event CustomEvent) error {
@@ -73,6 +76,8 @@ func (r CustomEvent) Render(w http.ResponseWriter) error {
}
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
r.Mutex.Lock()
defer r.Mutex.Unlock()
header := w.Header()
header["Content-Type"] = contentType

View File

@@ -4,7 +4,10 @@ import (
"encoding/base64"
"encoding/json"
"math/rand"
"net/url"
"regexp"
"strconv"
"strings"
"unsafe"
)
@@ -95,3 +98,95 @@ func GetJsonString(data any) string {
b, _ := json.Marshal(data)
return string(b)
}
// MaskSensitiveInfo masks sensitive information like URLs, IPs in a string
// Example:
// http://example.com -> http://***.com
// https://api.test.org/v1/users/123?key=secret -> https://***.org/***/***/?key=***
// https://sub.domain.co.uk/path/to/resource -> https://***.co.uk/***/***
// 192.168.1.1 -> ***.***.***.***
func MaskSensitiveInfo(str string) string {
// Mask URLs
urlPattern := regexp.MustCompile(`(http|https)://[^\s/$.?#].[^\s]*`)
str = urlPattern.ReplaceAllStringFunc(str, func(urlStr string) string {
u, err := url.Parse(urlStr)
if err != nil {
return urlStr
}
host := u.Host
if host == "" {
return urlStr
}
// Split host by dots
parts := strings.Split(host, ".")
if len(parts) < 2 {
// If less than 2 parts, just mask the whole host
return u.Scheme + "://***" + u.Path
}
// Keep the TLD (Top Level Domain) and mask the rest
var maskedHost string
if len(parts) == 2 {
// example.com -> ***.com
maskedHost = "***." + parts[len(parts)-1]
} else {
// Handle cases like sub.domain.co.uk or api.example.com
// Keep last 2 parts if they look like country code TLD (co.uk, com.cn, etc.)
lastPart := parts[len(parts)-1]
secondLastPart := parts[len(parts)-2]
if len(lastPart) == 2 && len(secondLastPart) <= 3 {
// Likely country code TLD like co.uk, com.cn
maskedHost = "***." + secondLastPart + "." + lastPart
} else {
// Regular TLD like .com, .org
maskedHost = "***." + lastPart
}
}
result := u.Scheme + "://" + maskedHost
// Mask path
if u.Path != "" && u.Path != "/" {
pathParts := strings.Split(strings.Trim(u.Path, "/"), "/")
maskedPathParts := make([]string, len(pathParts))
for i := range pathParts {
if pathParts[i] != "" {
maskedPathParts[i] = "***"
}
}
if len(maskedPathParts) > 0 {
result += "/" + strings.Join(maskedPathParts, "/")
}
} else if u.Path == "/" {
result += "/"
}
// Mask query parameters
if u.RawQuery != "" {
values, err := url.ParseQuery(u.RawQuery)
if err != nil {
// If can't parse query, just mask the whole query string
result += "?***"
} else {
maskedParams := make([]string, 0, len(values))
for key := range values {
maskedParams = append(maskedParams, key+"=***")
}
if len(maskedParams) > 0 {
result += "?" + strings.Join(maskedParams, "&")
}
}
}
return result
})
// Mask IP addresses
ipPattern := regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`)
str = ipPattern.ReplaceAllString(str, "***.***.***.***")
return str
}

150
common/totp.go Normal file
View File

@@ -0,0 +1,150 @@
package common
import (
"crypto/rand"
"fmt"
"os"
"strconv"
"strings"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
)
const (
// 备用码配置
BackupCodeLength = 8 // 备用码长度
BackupCodeCount = 4 // 生成备用码数量
// 限制配置
MaxFailAttempts = 5 // 最大失败尝试次数
LockoutDuration = 300 // 锁定时间(秒)
)
// GenerateTOTPSecret 生成TOTP密钥和配置
func GenerateTOTPSecret(accountName string) (*otp.Key, error) {
issuer := Get2FAIssuer()
return totp.Generate(totp.GenerateOpts{
Issuer: issuer,
AccountName: accountName,
Period: 30,
Digits: otp.DigitsSix,
Algorithm: otp.AlgorithmSHA1,
})
}
// ValidateTOTPCode 验证TOTP验证码
func ValidateTOTPCode(secret, code string) bool {
// 清理验证码格式
cleanCode := strings.ReplaceAll(code, " ", "")
if len(cleanCode) != 6 {
return false
}
// 验证验证码
return totp.Validate(cleanCode, secret)
}
// GenerateBackupCodes 生成备用恢复码
func GenerateBackupCodes() ([]string, error) {
codes := make([]string, BackupCodeCount)
for i := 0; i < BackupCodeCount; i++ {
code, err := generateRandomBackupCode()
if err != nil {
return nil, err
}
codes[i] = code
}
return codes, nil
}
// generateRandomBackupCode 生成单个备用码
func generateRandomBackupCode() (string, error) {
const charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
code := make([]byte, BackupCodeLength)
for i := range code {
randomBytes := make([]byte, 1)
_, err := rand.Read(randomBytes)
if err != nil {
return "", err
}
code[i] = charset[int(randomBytes[0])%len(charset)]
}
// 格式化为 XXXX-XXXX 格式
return fmt.Sprintf("%s-%s", string(code[:4]), string(code[4:])), nil
}
// ValidateBackupCode 验证备用码格式
func ValidateBackupCode(code string) bool {
// 移除所有分隔符并转为大写
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
if len(cleanCode) != BackupCodeLength {
return false
}
// 检查字符是否合法
for _, char := range cleanCode {
if !((char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9')) {
return false
}
}
return true
}
// NormalizeBackupCode 标准化备用码格式
func NormalizeBackupCode(code string) string {
cleanCode := strings.ToUpper(strings.ReplaceAll(code, "-", ""))
if len(cleanCode) == BackupCodeLength {
return fmt.Sprintf("%s-%s", cleanCode[:4], cleanCode[4:])
}
return code
}
// HashBackupCode 对备用码进行哈希
func HashBackupCode(code string) (string, error) {
normalizedCode := NormalizeBackupCode(code)
return Password2Hash(normalizedCode)
}
// Get2FAIssuer 获取2FA发行者名称
func Get2FAIssuer() string {
return SystemName
}
// getEnvOrDefault 获取环境变量或默认值
func getEnvOrDefault(key, defaultValue string) string {
if value, exists := os.LookupEnv(key); exists {
return value
}
return defaultValue
}
// ValidateNumericCode 验证数字验证码格式
func ValidateNumericCode(code string) (string, error) {
// 移除空格
code = strings.ReplaceAll(code, " ", "")
if len(code) != 6 {
return "", fmt.Errorf("验证码必须是6位数字")
}
// 检查是否为纯数字
if _, err := strconv.Atoi(code); err != nil {
return "", fmt.Errorf("验证码只能包含数字")
}
return code, nil
}
// GenerateQRCodeData 生成二维码数据
func GenerateQRCodeData(secret, username string) string {
issuer := Get2FAIssuer()
accountName := fmt.Sprintf("%s (%s)", username, issuer)
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&digits=6&period=30",
issuer, accountName, secret, issuer)
}

View File

@@ -49,6 +49,7 @@ const (
ChannelTypeCoze = 49
ChannelTypeKling = 50
ChannelTypeJimeng = 51
ChannelTypeVidu = 52
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
@@ -106,4 +107,5 @@ var ChannelBaseURLs = []string{
"https://api.coze.cn", //49
"https://api.klingai.com", //50
"https://visual.volcengineapi.com", //51
"https://api.vidu.cn", //52
}

View File

@@ -69,6 +69,12 @@ func testChannel(channel *model.Channel, testModel string) testResult {
newAPIError: nil,
}
}
if channel.Type == constant.ChannelTypeVidu {
return testResult{
localErr: errors.New("vidu channel test is not supported"),
newAPIError: nil,
}
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
@@ -203,7 +209,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeDoRequestFailed),
newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
}
}
var httpResp *http.Response
@@ -214,7 +220,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeBadResponse),
newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
}
}
}
@@ -230,7 +236,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
return testResult{
context: c,
localErr: errors.New("usage is nil"),
newAPIError: types.NewError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody),
newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
}
}
usage := usageA.(*dto.Usage)
@@ -240,7 +246,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
return testResult{
context: c,
localErr: err,
newAPIError: types.NewError(err, types.ErrorCodeReadResponseBodyFailed),
newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
}
}
info.PromptTokens = usage.PromptTokens
@@ -326,8 +332,11 @@ func TestChannel(c *gin.Context) {
}
channel, err := model.CacheGetChannel(channelId)
if err != nil {
common.ApiError(c, err)
return
channel, err = model.GetChannelById(channelId, true)
if err != nil {
common.ApiError(c, err)
return
}
}
//defer func() {
// if channel.ChannelInfo.IsMultiKey {
@@ -411,7 +420,7 @@ func testAllChannels(notify bool) error {
if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
if milliseconds > disableThreshold {
err := errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
newAPIError = types.NewError(err, types.ErrorCodeChannelResponseTimeExceeded)
newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
shouldBanChannel = true
}
}

View File

@@ -36,11 +36,30 @@ type OpenAIModel struct {
Parent string `json:"parent"`
}
type GoogleOpenAICompatibleModels []struct {
Name string `json:"name"`
Version string `json:"version"`
DisplayName string `json:"displayName"`
Description string `json:"description,omitempty"`
InputTokenLimit int `json:"inputTokenLimit"`
OutputTokenLimit int `json:"outputTokenLimit"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
MaxTemperature int `json:"maxTemperature,omitempty"`
}
type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"`
Success bool `json:"success"`
}
type GoogleOpenAICompatibleResponse struct {
Models []GoogleOpenAICompatibleModels `json:"models"`
NextPageToken string `json:"nextPageToken"`
}
func parseStatusFilter(statusParam string) int {
switch strings.ToLower(statusParam) {
case "enabled", "1":
@@ -52,6 +71,13 @@ func parseStatusFilter(statusParam string) int {
}
}
func clearChannelInfo(channel *model.Channel) {
if channel.ChannelInfo.IsMultiKey {
channel.ChannelInfo.MultiKeyDisabledReason = nil
channel.ChannelInfo.MultiKeyDisabledTime = nil
}
}
func GetAllChannels(c *gin.Context) {
pageInfo := common.GetPageQuery(c)
channelData := make([]*model.Channel, 0)
@@ -126,6 +152,10 @@ func GetAllChannels(c *gin.Context) {
}
}
for _, datum := range channelData {
clearChannelInfo(datum)
}
countQuery := model.DB.Model(&model.Channel{})
if statusFilter == common.ChannelStatusEnabled {
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
@@ -168,26 +198,59 @@ func FetchUpstreamModels(c *gin.Context) {
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
var url string
switch channel.Type {
case constant.ChannelTypeGemini:
url = fmt.Sprintf("%s/v1beta/openai/models", baseURL)
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
default:
url = fmt.Sprintf("%s/v1/models", baseURL)
}
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
var body []byte
if channel.Type == constant.ChannelTypeGemini {
body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
} else {
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
}
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
common.ApiError(c, err)
return
}
var result OpenAIModelsResponse
if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
})
return
var parseSuccess bool
// 适配特殊格式
switch channel.Type {
case constant.ChannelTypeGemini:
var googleResult GoogleOpenAICompatibleResponse
if err = json.Unmarshal(body, &googleResult); err == nil {
// 转换Google格式到OpenAI格式
for _, model := range googleResult.Models {
for _, gModel := range model {
result.Data = append(result.Data, OpenAIModel{
ID: gModel.Name,
})
}
}
parseSuccess = true
}
}
// 如果解析失败尝试OpenAI格式
if !parseSuccess {
if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
})
return
}
}
var ids []string
@@ -319,6 +382,10 @@ func SearchChannels(c *gin.Context) {
pagedData := channelData[startIdx:endIdx]
for _, datum := range pagedData {
clearChannelInfo(datum)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -342,6 +409,9 @@ func GetChannel(c *gin.Context) {
common.ApiError(c, err)
return
}
if channel != nil {
clearChannelInfo(channel)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -669,6 +739,7 @@ func DeleteChannelBatch(c *gin.Context) {
type PatchChannel struct {
model.Channel
MultiKeyMode *string `json:"multi_key_mode"`
KeyMode *string `json:"key_mode"` // 多key模式下密钥覆盖或者追加
}
func UpdateChannel(c *gin.Context) {
@@ -688,7 +759,7 @@ func UpdateChannel(c *gin.Context) {
return
}
// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
originChannel, err := model.GetChannelById(channel.Id, false)
originChannel, err := model.GetChannelById(channel.Id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -704,6 +775,69 @@ func UpdateChannel(c *gin.Context) {
if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
}
// 处理多key模式下的密钥追加/覆盖逻辑
if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
switch *channel.KeyMode {
case "append":
// 追加模式:将新密钥添加到现有密钥列表
if originChannel.Key != "" {
var newKeys []string
var existingKeys []string
// 解析现有密钥
if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
// JSON数组格式
var arr []json.RawMessage
if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
existingKeys = make([]string, len(arr))
for i, v := range arr {
existingKeys[i] = string(v)
}
}
} else {
// 换行分隔格式
existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
}
// 处理 Vertex AI 的特殊情况
if channel.Type == constant.ChannelTypeVertexAi {
// 尝试解析新密钥为JSON数组
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
array, err := getVertexArrayKeys(channel.Key)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "追加密钥解析失败: " + err.Error(),
})
return
}
newKeys = array
} else {
// 单个JSON密钥
newKeys = []string{channel.Key}
}
// 合并密钥
allKeys := append(existingKeys, newKeys...)
channel.Key = strings.Join(allKeys, "\n")
} else {
// 普通渠道的处理
inputKeys := strings.Split(channel.Key, "\n")
for _, key := range inputKeys {
key = strings.TrimSpace(key)
if key != "" {
newKeys = append(newKeys, key)
}
}
// 合并密钥
allKeys := append(existingKeys, newKeys...)
channel.Key = strings.Join(allKeys, "\n")
}
}
case "replace":
// 覆盖模式:直接使用新密钥(默认行为,不需要特殊处理)
}
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
@@ -711,6 +845,7 @@ func UpdateChannel(c *gin.Context) {
}
model.InitChannelCache()
channel.Key = ""
clearChannelInfo(&channel.Channel)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
@@ -914,3 +1049,409 @@ func CopyChannel(c *gin.Context) {
// success
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clone.Id}})
}
// MultiKeyManageRequest represents the request for multi-key management operations
type MultiKeyManageRequest struct {
ChannelId int `json:"channel_id"`
Action string `json:"action"` // "disable_key", "enable_key", "delete_disabled_keys", "get_key_status"
KeyIndex *int `json:"key_index,omitempty"` // for disable_key and enable_key actions
Page int `json:"page,omitempty"` // for get_key_status pagination
PageSize int `json:"page_size,omitempty"` // for get_key_status pagination
Status *int `json:"status,omitempty"` // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
}
// MultiKeyStatusResponse represents the response for key status query
type MultiKeyStatusResponse struct {
Keys []KeyStatus `json:"keys"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
// Statistics
EnabledCount int `json:"enabled_count"`
ManualDisabledCount int `json:"manual_disabled_count"`
AutoDisabledCount int `json:"auto_disabled_count"`
}
type KeyStatus struct {
Index int `json:"index"`
Status int `json:"status"` // 1: enabled, 2: disabled
DisabledTime int64 `json:"disabled_time,omitempty"`
Reason string `json:"reason,omitempty"`
KeyPreview string `json:"key_preview"` // first 10 chars of key for identification
}
// ManageMultiKeys handles multi-key management operations
func ManageMultiKeys(c *gin.Context) {
request := MultiKeyManageRequest{}
err := c.ShouldBindJSON(&request)
if err != nil {
common.ApiError(c, err)
return
}
channel, err := model.GetChannelById(request.ChannelId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "渠道不存在",
})
return
}
if !channel.ChannelInfo.IsMultiKey {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该渠道不是多密钥模式",
})
return
}
switch request.Action {
case "get_key_status":
keys := channel.GetKeys()
// Default pagination parameters
page := request.Page
pageSize := request.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 50 // Default page size
}
// Statistics for all keys (unchanged by filtering)
var enabledCount, manualDisabledCount, autoDisabledCount int
// Build all key status data first
var allKeyStatusList []KeyStatus
for i, key := range keys {
status := 1 // default enabled
var disabledTime int64
var reason string
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// Count for statistics (all keys)
switch status {
case 1:
enabledCount++
case 2:
manualDisabledCount++
case 3:
autoDisabledCount++
}
if status != 1 {
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
}
}
// Create key preview (first 10 chars)
keyPreview := key
if len(key) > 10 {
keyPreview = key[:10] + "..."
}
allKeyStatusList = append(allKeyStatusList, KeyStatus{
Index: i,
Status: status,
DisabledTime: disabledTime,
Reason: reason,
KeyPreview: keyPreview,
})
}
// Apply status filter if specified
var filteredKeyStatusList []KeyStatus
if request.Status != nil {
for _, keyStatus := range allKeyStatusList {
if keyStatus.Status == *request.Status {
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
}
}
} else {
filteredKeyStatusList = allKeyStatusList
}
// Calculate pagination based on filtered results
filteredTotal := len(filteredKeyStatusList)
totalPages := (filteredTotal + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
if page > totalPages {
page = totalPages
}
// Calculate range for current page
start := (page - 1) * pageSize
end := start + pageSize
if end > filteredTotal {
end = filteredTotal
}
// Get the page data
var pageKeyStatusList []KeyStatus
if start < filteredTotal {
pageKeyStatusList = filteredKeyStatusList[start:end]
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": MultiKeyStatusResponse{
Keys: pageKeyStatusList,
Total: filteredTotal, // Total of filtered results
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
EnabledCount: enabledCount, // Overall statistics
ManualDisabledCount: manualDisabledCount, // Overall statistics
AutoDisabledCount: autoDisabledCount, // Overall statistics
},
})
return
case "disable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要禁用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已禁用",
})
return
case "enable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要启用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
if channel.ChannelInfo.MultiKeyStatusList != nil {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已启用",
})
return
case "enable_all_keys":
// 清空所有禁用状态,使所有密钥回到默认启用状态
var enabledCount int
if channel.ChannelInfo.MultiKeyStatusList != nil {
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
}
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
})
return
case "disable_all_keys":
// 禁用所有启用的密钥
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
var disabledCount int
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
status := 1 // default enabled
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
// 只禁用当前启用的密钥
if status == 1 {
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
disabledCount++
}
}
if disabledCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有可禁用的密钥",
})
return
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
})
return
case "delete_disabled_keys":
keys := channel.GetKeys()
var remainingKeys []string
var deletedCount int
var newStatusList = make(map[int]int)
var newDisabledTime = make(map[int]int64)
var newDisabledReason = make(map[int]string)
newIndex := 0
for i, key := range keys {
status := 1 // default enabled
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// 只删除自动禁用status == 3的密钥保留启用status == 1和手动禁用status == 2的密钥
if status == 3 {
deletedCount++
} else {
remainingKeys = append(remainingKeys, key)
// 保留非自动禁用密钥的状态信息,重新索引
if status != 1 {
newStatusList[newIndex] = status
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
newDisabledTime[newIndex] = t
}
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
newDisabledReason[newIndex] = r
}
}
}
newIndex++
}
}
if deletedCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有需要删除的自动禁用密钥",
})
return
}
// Update channel with remaining keys
channel.Key = strings.Join(remainingKeys, "\n")
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
channel.ChannelInfo.MultiKeyStatusList = newStatusList
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
"data": deletedCount,
})
return
default:
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不支持的操作",
})
return
}
}

View File

@@ -28,19 +28,19 @@ func Playground(c *gin.Context) {
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied)
newAPIError = types.NewError(errors.New("暂不支持使用 access token"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
return
}
if playgroundRequest.Model == "" {
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest)
newAPIError = types.NewError(errors.New("请选择模型"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
return
}
c.Set("original_model", playgroundRequest.Model)
@@ -51,7 +51,7 @@ func Playground(c *gin.Context) {
group = userGroup
} else {
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied)
newAPIError = types.NewError(errors.New("无权访问该分组"), types.ErrorCodeAccessDenied, types.ErrOptionWithSkipRetry())
return
}
c.Set("group", group)
@@ -62,7 +62,7 @@ func Playground(c *gin.Context) {
// Write user context to ensure acceptUnsetRatio is available
userCache, err := model.GetUserCache(userId)
if err != nil {
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError)
newAPIError = types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
return
}
userCache.WriteContext(c)

View File

@@ -6,6 +6,7 @@ import (
"one-api/common"
"one-api/model"
"strconv"
"unicode/utf8"
"github.com/gin-gonic/gin"
)
@@ -63,7 +64,7 @@ func AddRedemption(c *gin.Context) {
common.ApiError(c, err)
return
}
if len(redemption.Name) == 0 || len(redemption.Name) > 20 {
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "兑换码名称长度必须在1-20之间",

View File

@@ -47,7 +47,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
err = relay.TextHelper(c)
}
if constant2.ErrorLogEnabled && err != nil {
if constant2.ErrorLogEnabled && err != nil && types.IsRecordErrorLog(err) {
// 保存错误日志到mysql中
userId := c.GetInt("id")
tokenName := c.GetString("token_name")
@@ -56,14 +56,21 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = err.ErrorType
other["error_type"] = err.GetErrorType()
other["error_code"] = err.GetErrorCode()
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
other["channel_name"] = c.GetString("channel_name")
other["channel_type"] = c.GetInt("channel_type")
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.Error(), tokenId, 0, false, userGroup, other)
adminInfo := make(map[string]interface{})
adminInfo["use_channel"] = c.GetStringSlice("use_channel")
isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
if isMultiKey {
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
}
other["admin_info"] = adminInfo
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
}
return err
@@ -128,7 +135,7 @@ func WssRelay(c *gin.Context) {
defer ws.Close()
if err != nil {
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed).ToOpenAIError())
helper.WssError(c, ws, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()).ToOpenAIError())
return
}
@@ -259,10 +266,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
}
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
if channel == nil {
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在数据库一致性已被破坏retry", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed)
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在数据库一致性已被破坏retry", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {
@@ -278,7 +285,7 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
if types.IsChannelError(openaiErr) {
return true
}
if types.IsLocalError(openaiErr) {
if types.IsSkipRetryError(openaiErr) {
return false
}
if retryTimes <= 0 {

View File

@@ -83,7 +83,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
taskResult := &relaycommon.TaskInfo{}
// try parse as New API response format
var responseItems dto.TaskResponse[model.Task]
if err = json.Unmarshal(responseBody, &responseItems); err == nil {
if err = json.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() {
t := responseItems.Data
taskResult.TaskID = t.TaskID
taskResult.Status = string(t.Status)

553
controller/twofa.go Normal file
View File

@@ -0,0 +1,553 @@
package controller
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// Setup2FARequest 设置2FA请求结构
type Setup2FARequest struct {
Code string `json:"code" binding:"required"`
}
// Verify2FARequest 验证2FA请求结构
type Verify2FARequest struct {
Code string `json:"code" binding:"required"`
}
// Setup2FAResponse 设置2FA响应结构
type Setup2FAResponse struct {
Secret string `json:"secret"`
QRCodeData string `json:"qr_code_data"`
BackupCodes []string `json:"backup_codes"`
}
// Setup2FA 初始化2FA设置
func Setup2FA(c *gin.Context) {
userId := c.GetInt("id")
// 检查用户是否已经启用2FA
existing, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if existing != nil && existing.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已启用2FA请先禁用后重新设置",
})
return
}
// 如果存在已禁用的2FA记录先删除它
if existing != nil && !existing.IsEnabled {
if err := existing.Delete(); err != nil {
common.ApiError(c, err)
return
}
existing = nil // 重置为nil后续将创建新记录
}
// 获取用户信息
user, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
// 生成TOTP密钥
key, err := common.GenerateTOTPSecret(user.Username)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成2FA密钥失败",
})
common.SysError("生成TOTP密钥失败: " + err.Error())
return
}
// 生成备用码
backupCodes, err := common.GenerateBackupCodes()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成备用码失败",
})
common.SysError("生成备用码失败: " + err.Error())
return
}
// 生成二维码数据
qrCodeData := common.GenerateQRCodeData(key.Secret(), user.Username)
// 创建或更新2FA记录暂未启用
twoFA := &model.TwoFA{
UserId: userId,
Secret: key.Secret(),
IsEnabled: false,
}
if existing != nil {
// 更新现有记录
twoFA.Id = existing.Id
err = twoFA.Update()
} else {
// 创建新记录
err = twoFA.Create()
}
if err != nil {
common.ApiError(c, err)
return
}
// 创建备用码记录
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "保存备用码失败",
})
common.SysError("保存备用码失败: " + err.Error())
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "开始设置两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "2FA设置初始化成功请使用认证器扫描二维码并输入验证码完成设置",
"data": Setup2FAResponse{
Secret: key.Secret(),
QRCodeData: qrCodeData,
BackupCodes: backupCodes,
},
})
}
// Enable2FA 启用2FA
func Enable2FA(c *gin.Context) {
var req Setup2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "请先完成2FA初始化设置",
})
return
}
if twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "2FA已经启用",
})
return
}
// 验证TOTP验证码
cleanCode, err := common.ValidateNumericCode(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if !common.ValidateTOTPCode(twoFA.Secret, cleanCode) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 启用2FA
if err := twoFA.Enable(); err != nil {
common.ApiError(c, err)
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "成功启用两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "两步验证启用成功",
})
}
// Disable2FA 禁用2FA
func Disable2FA(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码或备用码
cleanCode, err := common.ValidateNumericCode(req.Code)
isValidTOTP := false
isValidBackup := false
if err == nil {
// 尝试验证TOTP
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
}
if !isValidTOTP {
// 尝试验证备用码
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
if !isValidTOTP && !isValidBackup {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 禁用2FA
if err := model.DisableTwoFA(userId); err != nil {
common.ApiError(c, err)
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "禁用两步验证")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "两步验证已禁用",
})
}
// Get2FAStatus 获取用户2FA状态
func Get2FAStatus(c *gin.Context) {
userId := c.GetInt("id")
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
status := map[string]interface{}{
"enabled": false,
"locked": false,
}
if twoFA != nil {
status["enabled"] = twoFA.IsEnabled
status["locked"] = twoFA.IsLocked()
if twoFA.IsEnabled {
// 获取剩余备用码数量
backupCount, err := model.GetUnusedBackupCodeCount(userId)
if err != nil {
common.SysError("获取备用码数量失败: " + err.Error())
} else {
status["backup_codes_remaining"] = backupCount
}
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": status,
})
}
// RegenerateBackupCodes 重新生成备用码
func RegenerateBackupCodes(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码
cleanCode, err := common.ValidateNumericCode(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
valid, err := twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if !valid {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 生成新的备用码
backupCodes, err := common.GenerateBackupCodes()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成备用码失败",
})
common.SysError("生成备用码失败: " + err.Error())
return
}
// 保存新的备用码
if err := model.CreateBackupCodes(userId, backupCodes); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "保存备用码失败",
})
common.SysError("保存备用码失败: " + err.Error())
return
}
// 记录操作日志
model.RecordLog(userId, model.LogTypeSystem, "重新生成两步验证备用码")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "备用码重新生成成功",
"data": map[string]interface{}{
"backup_codes": backupCodes,
},
})
}
// Verify2FALogin 登录时验证2FA
func Verify2FALogin(c *gin.Context) {
var req Verify2FARequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
// 从会话中获取pending用户信息
session := sessions.Default(c)
pendingUserId := session.Get("pending_user_id")
if pendingUserId == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "会话已过期,请重新登录",
})
return
}
userId, ok := pendingUserId.(int)
if !ok {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "会话数据无效,请重新登录",
})
return
}
// 获取用户信息
user, err := model.GetUserById(userId, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户不存在",
})
return
}
// 获取2FA记录
twoFA, err := model.GetTwoFAByUserId(user.Id)
if err != nil {
common.ApiError(c, err)
return
}
if twoFA == nil || !twoFA.IsEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
// 验证TOTP验证码或备用码
cleanCode, err := common.ValidateNumericCode(req.Code)
isValidTOTP := false
isValidBackup := false
if err == nil {
// 尝试验证TOTP
isValidTOTP, _ = twoFA.ValidateTOTPAndUpdateUsage(cleanCode)
}
if !isValidTOTP {
// 尝试验证备用码
isValidBackup, err = twoFA.ValidateBackupCodeAndUpdateUsage(req.Code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
}
if !isValidTOTP && !isValidBackup {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码或备用码错误,请重试",
})
return
}
// 2FA验证成功清理pending会话信息并完成登录
session.Delete("pending_username")
session.Delete("pending_user_id")
session.Save()
setupLogin(user, c)
}
// Admin2FAStats 管理员获取2FA统计信息
func Admin2FAStats(c *gin.Context) {
stats, err := model.GetTwoFAStats()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": stats,
})
}
// AdminDisable2FA 管理员强制禁用用户2FA
func AdminDisable2FA(c *gin.Context) {
userIdStr := c.Param("id")
userId, err := strconv.Atoi(userIdStr)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户ID格式错误",
})
return
}
// 检查目标用户权限
targetUser, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
myRole := c.GetInt("role")
if myRole <= targetUser.Role && myRole != common.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权操作同级或更高级用户的2FA设置",
})
return
}
// 禁用2FA
if err := model.DisableTwoFA(userId); err != nil {
if errors.Is(err, model.ErrTwoFANotEnabled) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户未启用2FA",
})
return
}
common.ApiError(c, err)
return
}
// 记录操作日志
adminId := c.GetInt("id")
model.RecordLog(userId, model.LogTypeManage,
fmt.Sprintf("管理员(ID:%d)强制禁用了用户的两步验证", adminId))
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "用户2FA已被强制禁用",
})
}

View File

@@ -62,6 +62,32 @@ func Login(c *gin.Context) {
})
return
}
// 检查是否启用2FA
if model.IsTwoFAEnabled(user.Id) {
// 设置pending session等待2FA验证
session := sessions.Default(c)
session.Set("pending_username", user.Username)
session.Set("pending_user_id", user.Id)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"message": "无法保存会话信息,请重试",
"success": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"message": "请输入两步验证码",
"success": true,
"data": map[string]interface{}{
"require_2fa": true,
},
})
return
}
setupLogin(&user, c)
}

View File

@@ -1,7 +1,9 @@
package dto
type ChannelSettings struct {
ForceFormat bool `json:"force_format,omitempty"`
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
Proxy string `json:"proxy"`
ForceFormat bool `json:"force_format,omitempty"`
ThinkingToContent bool `json:"thinking_to_content,omitempty"`
Proxy string `json:"proxy"`
PassThroughBodyEnabled bool `json:"pass_through_body_enabled,omitempty"`
SystemPrompt string `json:"system_prompt,omitempty"`
}

View File

@@ -2,6 +2,7 @@ package dto
import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/types"
)
@@ -284,14 +285,9 @@ func (c *ClaudeRequest) ParseSystem() []ClaudeMediaMessage {
return mediaContent
}
type ClaudeError struct {
Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
}
type ClaudeErrorWithStatusCode struct {
Error ClaudeError `json:"error"`
StatusCode int `json:"status_code"`
Error types.ClaudeError `json:"error"`
StatusCode int `json:"status_code"`
LocalError bool
}
@@ -303,7 +299,7 @@ type ClaudeResponse struct {
Completion string `json:"completion,omitempty"`
StopReason string `json:"stop_reason,omitempty"`
Model string `json:"model,omitempty"`
Error *types.ClaudeError `json:"error,omitempty"`
Error any `json:"error,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *ClaudeMediaMessage `json:"content_block,omitempty"`
@@ -324,6 +320,42 @@ func (c *ClaudeResponse) GetIndex() int {
return *c.Index
}
// GetClaudeError 从动态错误类型中提取ClaudeError结构
func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError {
if c.Error == nil {
return nil
}
switch err := c.Error.(type) {
case types.ClaudeError:
return &err
case *types.ClaudeError:
return err
case map[string]interface{}:
// 处理从JSON解析来的map结构
claudeErr := &types.ClaudeError{}
if errType, ok := err["type"].(string); ok {
claudeErr.Type = errType
}
if errMsg, ok := err["message"].(string); ok {
claudeErr.Message = errMsg
}
return claudeErr
case string:
// 处理简单字符串错误
return &types.ClaudeError{
Type: "error",
Message: err,
}
default:
// 未知类型,尝试转换为字符串
return &types.ClaudeError{
Type: "unknown_error",
Message: fmt.Sprintf("%v", err),
}
}
}
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`

View File

@@ -1,6 +1,9 @@
package gemini
package dto
import "encoding/json"
import (
"encoding/json"
"one-api/common"
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
@@ -32,7 +35,7 @@ func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
MimeTypeSnake string `json:"mime_type"`
}
if err := json.Unmarshal(data, &aux); err != nil {
if err := common.Unmarshal(data, &aux); err != nil {
return err
}
@@ -53,7 +56,7 @@ type FunctionCall struct {
Arguments any `json:"args"`
}
type FunctionResponse struct {
type GeminiFunctionResponse struct {
Name string `json:"name"`
Response map[string]interface{} `json:"response"`
}
@@ -78,7 +81,7 @@ type GeminiPart struct {
Thought bool `json:"thought,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
FileData *GeminiFileData `json:"fileData,omitempty"`
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
@@ -93,7 +96,7 @@ func (p *GeminiPart) UnmarshalJSON(data []byte) error {
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
}
if err := json.Unmarshal(data, &aux); err != nil {
if err := common.Unmarshal(data, &aux); err != nil {
return err
}

View File

@@ -7,15 +7,15 @@ import (
)
type ResponseFormat struct {
Type string `json:"type,omitempty"`
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
Type string `json:"type,omitempty"`
JsonSchema json.RawMessage `json:"json_schema,omitempty"`
}
type FormatJsonSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema any `json:"schema,omitempty"`
Strict any `json:"strict,omitempty"`
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema any `json:"schema,omitempty"`
Strict json.RawMessage `json:"strict,omitempty"`
}
type GeneralOpenAIRequest struct {
@@ -73,6 +73,15 @@ func (r *GeneralOpenAIRequest) ToMap() map[string]any {
return result
}
func (r *GeneralOpenAIRequest) GetSystemRoleName() string {
if strings.HasPrefix(r.Model, "o") {
if !strings.HasPrefix(r.Model, "o1-mini") && !strings.HasPrefix(r.Model, "o1-preview") {
return "developer"
}
}
return "system"
}
type ToolCallRequest struct {
ID string `json:"id,omitempty"`
Type string `json:"type"`

View File

@@ -2,12 +2,18 @@ package dto
import (
"encoding/json"
"fmt"
"one-api/types"
)
type SimpleResponse struct {
Usage `json:"usage"`
Error *OpenAIError `json:"error"`
Error any `json:"error"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (s *SimpleResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(s.Error)
}
type TextResponse struct {
@@ -31,10 +37,15 @@ type OpenAITextResponse struct {
Object string `json:"object"`
Created any `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Error *types.OpenAIError `json:"error,omitempty"`
Error any `json:"error,omitempty"`
Usage `json:"usage"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (o *OpenAITextResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
type OpenAIEmbeddingResponseItem struct {
Object string `json:"object"`
Index int `json:"index"`
@@ -217,7 +228,7 @@ type OpenAIResponsesResponse struct {
Object string `json:"object"`
CreatedAt int `json:"created_at"`
Status string `json:"status"`
Error *types.OpenAIError `json:"error,omitempty"`
Error any `json:"error,omitempty"`
IncompleteDetails *IncompleteDetails `json:"incomplete_details,omitempty"`
Instructions string `json:"instructions"`
MaxOutputTokens int `json:"max_output_tokens"`
@@ -237,6 +248,11 @@ type OpenAIResponsesResponse struct {
Metadata json.RawMessage `json:"metadata"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func (o *OpenAIResponsesResponse) GetOpenAIError() *types.OpenAIError {
return GetOpenAIError(o.Error)
}
type IncompleteDetails struct {
Reasoning string `json:"reasoning"`
}
@@ -276,3 +292,45 @@ type ResponsesStreamResponse struct {
Delta string `json:"delta,omitempty"`
Item *ResponsesOutput `json:"item,omitempty"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构
func GetOpenAIError(errorField any) *types.OpenAIError {
if errorField == nil {
return nil
}
switch err := errorField.(type) {
case types.OpenAIError:
return &err
case *types.OpenAIError:
return err
case map[string]interface{}:
// 处理从JSON解析来的map结构
openaiErr := &types.OpenAIError{}
if errType, ok := err["type"].(string); ok {
openaiErr.Type = errType
}
if errMsg, ok := err["message"].(string); ok {
openaiErr.Message = errMsg
}
if errParam, ok := err["param"].(string); ok {
openaiErr.Param = errParam
}
if errCode, ok := err["code"]; ok {
openaiErr.Code = errCode
}
return openaiErr
case string:
// 处理简单字符串错误
return &types.OpenAIError{
Type: "error",
Message: err,
}
default:
// 未知类型,尝试转换为字符串
return &types.OpenAIError{
Type: "unknown_error",
Message: fmt.Sprintf("%v", err),
}
}
}

2
go.mod
View File

@@ -45,6 +45,7 @@ require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
github.com/aws/smithy-go v1.20.2 // indirect
github.com/boombuler/barcode v1.1.0 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -79,6 +80,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.1 // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect

6
go.sum
View File

@@ -20,6 +20,10 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/boombuler/barcode v1.1.0 h1:ChaYjBR63fr4LFyGn8E8nt7dBSt3MiU3zMOZqFvVkHo=
github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0=
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
@@ -169,6 +173,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=

View File

@@ -585,6 +585,19 @@
"渠道权重": "渠道权重",
"渠道额外设置": "渠道额外设置",
"此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:": "此项可选,用于配置渠道特定设置,为一个 JSON 字符串,例如:",
"强制格式化": "强制格式化",
"强制将响应格式化为 OpenAI 标准格式只适用于OpenAI渠道类型": "强制将响应格式化为 OpenAI 标准格式只适用于OpenAI渠道类型",
"思考内容转换": "思考内容转换",
"将 reasoning_content 转换为 <think> 标签拼接到内容中": "将 reasoning_content 转换为 <think> 标签拼接到内容中",
"透传请求体": "透传请求体",
"启用请求体透传功能": "启用请求体透传功能",
"代理地址": "代理地址",
"例如: socks5://user:pass@host:port": "例如: socks5://user:pass@host:port",
"用于配置网络代理": "用于配置网络代理",
"用于配置网络代理,支持 socks5 协议": "用于配置网络代理,支持 socks5 协议",
"系统提示词": "系统提示词",
"输入系统提示词,用户的系统提示词将优先于此设置": "输入系统提示词,用户的系统提示词将优先于此设置",
"用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置": "用户优先:如果用户在请求中指定了系统提示词,将优先使用用户的设置",
"参数覆盖": "参数覆盖",
"此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:": "此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:",
"请输入组织org-xxx": "请输入组织org-xxx",

View File

@@ -122,6 +122,7 @@ func authHelper(c *gin.Context, minRole int) {
c.Set("role", role)
c.Set("id", id)
c.Set("group", session.Get("group"))
c.Set("user_group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
//userCache, err := model.GetUserCache(id.(int))

View File

@@ -100,6 +100,10 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
if modelRequest.Model == "" {
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
return
}
var selectGroup string
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
@@ -107,18 +111,17 @@ func Distribute() func(c *gin.Context) {
if userGroup == "auto" {
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
}
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败distributor: %s", showGroup, modelRequest.Model, err.Error())
message := fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败(数据库一致性已被破坏,distributor: %s", showGroup, modelRequest.Model, err.Error())
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
//if channel != nil {
// common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
// message = "数据库一致性已被破坏,请联系管理员"
//}
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 可用渠道不存在(数据库一致性已被破坏,distributor", userGroup, modelRequest.Model))
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 可用渠道distributor", userGroup, modelRequest.Model))
return
}
}
@@ -244,7 +247,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
c.Set("original_model", modelName) // for retry
if channel == nil {
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed)
return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
@@ -266,6 +269,9 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel.ChannelInfo.IsMultiKey {
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
} else {
// 必须设置为 false否则在重试到单个 key 的时候会导致日志显示错误
common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
}
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
common.SetContextKey(c, constant.ContextKeyChannelKey, key)

View File

@@ -136,7 +136,7 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
}
}
} else {
return nil, errors.New("channel not found")
return nil, nil
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
@@ -284,6 +284,21 @@ func FixAbility() (int, int, error) {
return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
}
defer fixLock.Unlock()
// truncate abilities table
if common.UsingSQLite {
err := DB.Exec("DELETE FROM abilities").Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
return 0, 0, err
}
} else {
err := DB.Exec("TRUNCATE TABLE abilities").Error
if err != nil {
common.SysError(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
return 0, 0, err
}
}
var channels []*Channel
// Find all channels
err := DB.Model(&Channel{}).Find(&channels).Error

View File

@@ -41,19 +41,25 @@ type Channel struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
AutoBan *int `json:"auto_ban" gorm:"default:1"`
OtherInfo string `json:"other_info"`
Settings string `json:"settings"`
Tag *string `json:"tag" gorm:"index"`
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
ParamOverride *string `json:"param_override" gorm:"type:text"`
// add after v0.8.5
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
// cache info
Keys []string `json:"-" gorm:"-"`
}
type ChannelInfo struct {
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式
MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量
MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表key index -> status
MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表key index -> reason
MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表key index -> time
MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引
MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
}
// Value implements driver.Valuer interface
@@ -67,15 +73,18 @@ func (c *ChannelInfo) Scan(value interface{}) error {
return common.Unmarshal(bytesValue, c)
}
func (channel *Channel) getKeys() []string {
func (channel *Channel) GetKeys() []string {
if channel.Key == "" {
return []string{}
}
if len(channel.Keys) > 0 {
return channel.Keys
}
trimmed := strings.TrimSpace(channel.Key)
// If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios)
if strings.HasPrefix(trimmed, "[") {
var arr []json.RawMessage
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
res := make([]string, len(arr))
for i, v := range arr {
res[i] = string(v)
@@ -95,7 +104,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
}
// Obtain all keys (split by \n)
keys := channel.getKeys()
keys := channel.GetKeys()
if len(keys) == 0 {
// No keys available, return error, should disable the channel
return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey)
@@ -138,7 +147,7 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) {
channelInfo, err := CacheGetChannelInfo(channel.Id)
if err != nil {
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed)
return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
}
//println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex)
defer func() {
@@ -197,7 +206,7 @@ func (channel *Channel) GetGroups() []string {
func (channel *Channel) GetOtherInfo() map[string]interface{} {
otherInfo := make(map[string]interface{})
if channel.OtherInfo != "" {
err := json.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo)
if err != nil {
common.SysError("failed to unmarshal other info: " + err.Error())
}
@@ -425,7 +434,7 @@ func (channel *Channel) Update() error {
trimmed := strings.TrimSpace(keyStr)
if strings.HasPrefix(trimmed, "[") {
var arr []json.RawMessage
if err := json.Unmarshal([]byte(trimmed), &arr); err == nil {
if err := common.Unmarshal([]byte(trimmed), &arr); err == nil {
keys = make([]string, len(arr))
for i, v := range arr {
keys[i] = string(v)
@@ -522,8 +531,8 @@ func CleanupChannelPollingLocks() {
})
}
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
keys := channel.getKeys()
func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) {
keys := channel.GetKeys()
if len(keys) == 0 {
channel.Status = status
} else {
@@ -541,6 +550,14 @@ func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int) {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
} else {
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason
channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp()
}
if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize {
channel.Status = common.ChannelStatusAutoDisabled
@@ -563,7 +580,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
}
if channelCache.ChannelInfo.IsMultiKey {
// 如果是多Key模式更新缓存中的状态
handlerMultiKeyUpdate(channelCache, usingKey, status)
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
//CacheUpdateChannel(channelCache)
//return true
} else {
@@ -571,10 +588,6 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if channelCache.Status == status {
return false
}
// 如果缓存渠道不存在(说明已经被禁用),且要设置的状态不为启用,直接返回
if status != common.ChannelStatusEnabled {
return false
}
CacheUpdateChannelStatus(channelId, status)
}
}
@@ -598,7 +611,7 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
if channel.ChannelInfo.IsMultiKey {
beforeStatus := channel.Status
handlerMultiKeyUpdate(channel, usingKey, status)
handlerMultiKeyUpdate(channel, usingKey, status, reason)
if beforeStatus != channel.Status {
shouldUpdateAbilities = true
}
@@ -778,7 +791,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
func (channel *Channel) ValidateSettings() error {
channelParams := &dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), channelParams)
err := common.Unmarshal([]byte(*channel.Setting), channelParams)
if err != nil {
return err
}
@@ -789,7 +802,7 @@ func (channel *Channel) ValidateSettings() error {
func (channel *Channel) GetSetting() dto.ChannelSettings {
setting := dto.ChannelSettings{}
if channel.Setting != nil && *channel.Setting != "" {
err := json.Unmarshal([]byte(*channel.Setting), &setting)
err := common.Unmarshal([]byte(*channel.Setting), &setting)
if err != nil {
common.SysError("failed to unmarshal setting: " + err.Error())
channel.Setting = nil // 清空设置以避免后续错误
@@ -800,7 +813,7 @@ func (channel *Channel) GetSetting() dto.ChannelSettings {
}
func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
settingBytes, err := json.Marshal(setting)
settingBytes, err := common.Marshal(setting)
if err != nil {
common.SysError("failed to marshal setting: " + err.Error())
return
@@ -811,7 +824,7 @@ func (channel *Channel) SetSetting(setting dto.ChannelSettings) {
func (channel *Channel) GetParamOverride() map[string]interface{} {
paramOverride := make(map[string]interface{})
if channel.ParamOverride != nil && *channel.ParamOverride != "" {
err := json.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
err := common.Unmarshal([]byte(*channel.ParamOverride), &paramOverride)
if err != nil {
common.SysError("failed to unmarshal param override: " + err.Error())
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"math/rand"
"one-api/common"
"one-api/constant"
"one-api/setting"
"sort"
"strings"
@@ -66,6 +67,20 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
//channelsIDM = newChannelId2channel
for i, channel := range newChannelId2channel {
if channel.ChannelInfo.IsMultiKey {
channel.Keys = channel.GetKeys()
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
if oldChannel, ok := channelsIDM[i]; ok {
// 存在旧的渠道如果是多key且轮询保留轮询索引信息
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
}
}
}
}
}
channelsIDM = newChannelId2channel
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
@@ -130,7 +145,7 @@ func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
channels := group2model2channels[group][model]
if len(channels) == 0 {
return nil, errors.New("channel not found")
return nil, nil
}
if len(channels) == 1 {
@@ -203,9 +218,6 @@ func CacheGetChannel(id int) (*Channel, error) {
if !ok {
return nil, fmt.Errorf("渠道# %d已不存在", id)
}
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return c, nil
}
@@ -224,9 +236,6 @@ func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
if !ok {
return nil, fmt.Errorf("渠道# %d已不存在", id)
}
if c.Status != common.ChannelStatusEnabled {
return nil, fmt.Errorf("渠道# %d已被禁用", id)
}
return &c.ChannelInfo, nil
}
@@ -239,6 +248,20 @@ func CacheUpdateChannelStatus(id int, status int) {
if channel, ok := channelsIDM[id]; ok {
channel.Status = status
}
if status != common.ChannelStatusEnabled {
// delete the channel from group2model2channels
for group, model2channels := range group2model2channels {
for model, channels := range model2channels {
for i, channelId := range channels {
if channelId == id {
// remove the channel from the slice
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...)
break
}
}
}
}
}
}
func CacheUpdateChannel(channel *Channel) {

View File

@@ -251,6 +251,8 @@ func migrateDB() error {
&QuotaData{},
&Task{},
&Setup{},
&TwoFA{},
&TwoFABackupCode{},
)
if err != nil {
return err
@@ -277,6 +279,8 @@ func migrateDBFast() error {
{&QuotaData{}, "QuotaData"},
{&Task{}, "Task"},
{&Setup{}, "Setup"},
{&TwoFA{}, "TwoFA"},
{&TwoFABackupCode{}, "TwoFABackupCode"},
}
// 动态计算migration数量确保errChan缓冲区足够大
errChan := make(chan error, len(migrations))

322
model/twofa.go Normal file
View File

@@ -0,0 +1,322 @@
package model
import (
"errors"
"fmt"
"one-api/common"
"time"
"gorm.io/gorm"
)
var ErrTwoFANotEnabled = errors.New("用户未启用2FA")
// TwoFA 用户2FA设置表
type TwoFA struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"unique;not null;index"`
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥不返回给前端
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
LockedUntil *time.Time `json:"locked_until,omitempty"`
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// TwoFABackupCode 备用码使用记录表
type TwoFABackupCode struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"not null;index"`
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
IsUsed bool `json:"is_used" gorm:"default:false"`
UsedAt *time.Time `json:"used_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
}
// GetTwoFAByUserId 根据用户ID获取2FA设置
func GetTwoFAByUserId(userId int) (*TwoFA, error) {
if userId == 0 {
return nil, errors.New("用户ID不能为空")
}
var twoFA TwoFA
err := DB.Where("user_id = ?", userId).First(&twoFA).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil // 返回nil表示未设置2FA
}
return nil, err
}
return &twoFA, nil
}
// IsTwoFAEnabled 检查用户是否启用了2FA
func IsTwoFAEnabled(userId int) bool {
twoFA, err := GetTwoFAByUserId(userId)
if err != nil || twoFA == nil {
return false
}
return twoFA.IsEnabled
}
// CreateTwoFA 创建2FA设置
func (t *TwoFA) Create() error {
// 检查用户是否已存在2FA设置
existing, err := GetTwoFAByUserId(t.UserId)
if err != nil {
return err
}
if existing != nil {
return errors.New("用户已存在2FA设置")
}
// 验证用户存在
var user User
if err := DB.First(&user, t.UserId).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("用户不存在")
}
return err
}
return DB.Create(t).Error
}
// Update 更新2FA设置
func (t *TwoFA) Update() error {
if t.Id == 0 {
return errors.New("2FA记录ID不能为空")
}
return DB.Save(t).Error
}
// Delete 删除2FA设置
func (t *TwoFA) Delete() error {
if t.Id == 0 {
return errors.New("2FA记录ID不能为空")
}
// 使用事务确保原子性
return DB.Transaction(func(tx *gorm.DB) error {
// 同时删除相关的备用码记录(硬删除)
if err := tx.Unscoped().Where("user_id = ?", t.UserId).Delete(&TwoFABackupCode{}).Error; err != nil {
return err
}
// 硬删除2FA记录
return tx.Unscoped().Delete(t).Error
})
}
// ResetFailedAttempts 重置失败尝试次数
func (t *TwoFA) ResetFailedAttempts() error {
t.FailedAttempts = 0
t.LockedUntil = nil
return t.Update()
}
// IncrementFailedAttempts 增加失败尝试次数
func (t *TwoFA) IncrementFailedAttempts() error {
t.FailedAttempts++
// 检查是否需要锁定
if t.FailedAttempts >= common.MaxFailAttempts {
lockUntil := time.Now().Add(time.Duration(common.LockoutDuration) * time.Second)
t.LockedUntil = &lockUntil
}
return t.Update()
}
// IsLocked 检查账户是否被锁定
func (t *TwoFA) IsLocked() bool {
if t.LockedUntil == nil {
return false
}
return time.Now().Before(*t.LockedUntil)
}
// CreateBackupCodes 创建备用码
func CreateBackupCodes(userId int, codes []string) error {
return DB.Transaction(func(tx *gorm.DB) error {
// 先删除现有的备用码
if err := tx.Where("user_id = ?", userId).Delete(&TwoFABackupCode{}).Error; err != nil {
return err
}
// 创建新的备用码记录
for _, code := range codes {
hashedCode, err := common.HashBackupCode(code)
if err != nil {
return err
}
backupCode := TwoFABackupCode{
UserId: userId,
CodeHash: hashedCode,
IsUsed: false,
}
if err := tx.Create(&backupCode).Error; err != nil {
return err
}
}
return nil
})
}
// ValidateBackupCode 验证并使用备用码
func ValidateBackupCode(userId int, code string) (bool, error) {
if !common.ValidateBackupCode(code) {
return false, errors.New("验证码或备用码不正确")
}
normalizedCode := common.NormalizeBackupCode(code)
// 查找未使用的备用码
var backupCodes []TwoFABackupCode
if err := DB.Where("user_id = ? AND is_used = false", userId).Find(&backupCodes).Error; err != nil {
return false, err
}
// 验证备用码
for _, bc := range backupCodes {
if common.ValidatePasswordAndHash(normalizedCode, bc.CodeHash) {
// 标记为已使用
now := time.Now()
bc.IsUsed = true
bc.UsedAt = &now
if err := DB.Save(&bc).Error; err != nil {
return false, err
}
return true, nil
}
}
return false, nil
}
// GetUnusedBackupCodeCount 获取未使用的备用码数量
func GetUnusedBackupCodeCount(userId int) (int, error) {
var count int64
err := DB.Model(&TwoFABackupCode{}).Where("user_id = ? AND is_used = false", userId).Count(&count).Error
return int(count), err
}
// DisableTwoFA 禁用用户的2FA
func DisableTwoFA(userId int) error {
twoFA, err := GetTwoFAByUserId(userId)
if err != nil {
return err
}
if twoFA == nil {
return ErrTwoFANotEnabled
}
// 删除2FA设置和备用码
return twoFA.Delete()
}
// EnableTwoFA 启用2FA
func (t *TwoFA) Enable() error {
t.IsEnabled = true
t.FailedAttempts = 0
t.LockedUntil = nil
return t.Update()
}
// ValidateTOTPAndUpdateUsage 验证TOTP并更新使用记录
func (t *TwoFA) ValidateTOTPAndUpdateUsage(code string) (bool, error) {
// 检查是否被锁定
if t.IsLocked() {
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
}
// 验证TOTP码
if !common.ValidateTOTPCode(t.Secret, code) {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
// 验证成功,重置失败次数并更新最后使用时间
now := time.Now()
t.FailedAttempts = 0
t.LockedUntil = nil
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
}
// ValidateBackupCodeAndUpdateUsage 验证备用码并更新使用记录
func (t *TwoFA) ValidateBackupCodeAndUpdateUsage(code string) (bool, error) {
// 检查是否被锁定
if t.IsLocked() {
return false, fmt.Errorf("账户已被锁定,请在%v后重试", t.LockedUntil.Format("2006-01-02 15:04:05"))
}
// 验证备用码
valid, err := ValidateBackupCode(t.UserId, code)
if err != nil {
return false, err
}
if !valid {
// 增加失败次数
if err := t.IncrementFailedAttempts(); err != nil {
common.SysError("更新2FA失败次数失败: " + err.Error())
}
return false, nil
}
// 验证成功,重置失败次数并更新最后使用时间
now := time.Now()
t.FailedAttempts = 0
t.LockedUntil = nil
t.LastUsedAt = &now
if err := t.Update(); err != nil {
common.SysError("更新2FA使用记录失败: " + err.Error())
}
return true, nil
}
// GetTwoFAStats 获取2FA统计信息管理员使用
func GetTwoFAStats() (map[string]interface{}, error) {
var totalUsers, enabledUsers int64
// 总用户数
if err := DB.Model(&User{}).Count(&totalUsers).Error; err != nil {
return nil, err
}
// 启用2FA的用户数
if err := DB.Model(&TwoFA{}).Where("is_enabled = true").Count(&enabledUsers).Error; err != nil {
return nil, err
}
enabledRate := float64(0)
if totalUsers > 0 {
enabledRate = float64(enabledUsers) / float64(totalUsers) * 100
}
return map[string]interface{}{
"total_users": totalUsers,
"enabled_users": enabledUsers,
"enabled_rate": fmt.Sprintf("%.1f%%", enabledRate),
}, nil
}

View File

@@ -62,7 +62,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
promptTokens := 0
@@ -75,7 +75,7 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -90,18 +90,18 @@ func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)

View File

@@ -26,6 +26,7 @@ type Adaptor interface {
GetModelList() []string
GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error)
}
type TaskAdaptor interface {

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {

View File

@@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {

View File

@@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
var fullTextResponse dto.FlexibleEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
@@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{

View File

@@ -223,7 +223,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
helper.SetEventStreamHeaders(c)
// 处理流式请求的 ping 保活
generalSettings := operation_setting.GetGeneralSetting()
if generalSettings.PingIntervalEnabled {
if generalSettings.PingIntervalEnabled && !info.DisablePing {
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
stopPinger = startPingKeepAlive(c, pingInterval)
// 使用defer确保在任何情况下都能停止ping goroutine

View File

@@ -22,6 +22,11 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model)
c.Set("converted_request", request)

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
@@ -43,15 +48,15 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
keyParts := strings.Split(info.ApiKey, "|")
keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 0 || keyParts[0] == "" {
return errors.New("invalid API key: authorization token is required")
}
if len(keyParts) > 1 {
if keyParts[1] != "" {
req.Set("appid", keyParts[1])
}
}
return errors.New("invalid API key: authorization token is required")
}
if len(keyParts) > 1 {
if keyParts[1] != "" {
req.Set("appid", keyParts[1])
}
}
req.Set("Authorization", "Bearer "+keyParts[0])
return nil
}

View File

@@ -24,6 +24,11 @@ type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return request, nil
}

View File

@@ -612,8 +612,8 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
common.SysError("error unmarshalling stream response: " + err.Error())
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
@@ -704,8 +704,8 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
if requestMode == RequestModeCompletion {
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -17,6 +17,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *common.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
// ConvertAudioRequest implements channel.Adaptor.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")

View File

@@ -19,6 +19,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -24,6 +24,11 @@ type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -1,14 +1,13 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/setting/model_setting"
@@ -21,10 +20,33 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
if len(request.Contents) > 0 {
for i, content := range request.Contents {
if i == 0 {
if request.Contents[0].Role == "" {
request.Contents[0].Role = "user"
}
}
for _, part := range content.Parts {
if part.FileData != nil {
if part.FileData.MimeType == "" && strings.Contains(part.FileData.FileUri, "www.youtube.com") {
part.FileData.MimeType = "video/webm"
}
}
}
}
}
return request, nil
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := openai.Adaptor{}
oaiReq, err := adaptor.ConvertClaudeRequest(c, info, req)
if err != nil {
return nil, err
}
return a.ConvertOpenAIRequest(c, info, oaiReq.(*dto.GeneralOpenAIRequest))
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -49,13 +71,13 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
// build gemini imagen request
geminiRequest := GeminiImageRequest{
Instances: []GeminiImageInstance{
geminiRequest := dto.GeminiImageRequest{
Instances: []dto.GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: GeminiImageParameters{
Parameters: dto.GeminiImageParameters{
SampleCount: request.N,
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
@@ -136,9 +158,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
}
// only process the first input
geminiRequest := GeminiEmbeddingRequest{
Content: GeminiChatContent{
Parts: []GeminiPart{
geminiRequest := dto.GeminiEmbeddingRequest{
Content: dto.GeminiChatContent{
Parts: []dto.GeminiPart{
{
Text: inputs[0],
},
@@ -171,6 +193,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
if info.IsStream {
info.DisablePing = true
return GeminiTextGenerationStreamHandler(c, info, resp)
} else {
return GeminiTextGenerationHandler(c, info, resp)
@@ -208,60 +231,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
}
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -1,6 +1,7 @@
package gemini
import (
"github.com/pkg/errors"
"io"
"net/http"
"one-api/common"
@@ -20,7 +21,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if common.DebugEnabled {
@@ -28,10 +29,10 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
}
// 解析为 Gemini 原生响应格式
var geminiResponse GeminiChatResponse
var geminiResponse dto.GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// 计算使用量(基于 UsageMetadata
@@ -54,7 +55,7 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
// 直接返回 Gemini 原生格式的 JSON 响应
jsonResponse, err := common.Marshal(geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
@@ -71,7 +72,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
@@ -110,10 +111,14 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
if err != nil {
common.LogError(c, err.Error())
}
info.SendResponseCount++
return true
})
if info.SendResponseCount == 0 {
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
}
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258

View File

@@ -9,6 +9,7 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -80,7 +81,7 @@ func clampThinkingBudget(modelName string, budget int) int {
return budget
}
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
@@ -92,7 +93,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
if len(parts) == 2 && parts[1] != "" {
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(clampedBudget),
IncludeThoughts: true,
}
@@ -112,11 +113,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
}
if isUnsupported {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
IncludeThoughts: true,
}
} else {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
IncludeThoughts: true,
}
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
@@ -127,7 +128,7 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
}
} else if strings.HasSuffix(modelName, "-nothinking") {
if !isNew25Pro {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
geminiRequest.GenerationConfig.ThinkingConfig = &dto.GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0),
}
}
@@ -136,11 +137,11 @@ func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayIn
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
GenerationConfig: GeminiChatGenerationConfig{
geminiRequest := dto.GeminiChatRequest{
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
GenerationConfig: dto.GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
@@ -157,9 +158,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
ThinkingAdaptor(&geminiRequest, info)
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
safetySettings := make([]dto.GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
safetySettings = append(safetySettings, dto.GeminiChatSafetySettings{
Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category),
})
@@ -197,17 +198,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
functions = append(functions, tool.Function)
}
if codeExecution {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
CodeExecution: make(map[string]string),
})
}
if googleSearch {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
GoogleSearch: make(map[string]string),
})
}
if len(functions) > 0 {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
geminiRequest.Tools = append(geminiRequest.Tools, dto.GeminiChatTool{
FunctionDeclarations: functions,
})
}
@@ -219,9 +220,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
if len(textRequest.ResponseFormat.JsonSchema) > 0 {
// 先将json.RawMessage解析
var jsonSchema dto.FormatJsonSchema
if err := common.Unmarshal(textRequest.ResponseFormat.JsonSchema, &jsonSchema); err == nil {
cleanedSchema := removeAdditionalPropertiesWithDepth(jsonSchema.Schema, 0)
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
}
}
}
tool_call_ids := make(map[string]string)
@@ -233,7 +238,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
continue
} else if message.Role == "tool" || message.Role == "function" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
geminiRequest.Contents = append(geminiRequest.Contents, dto.GeminiChatContent{
Role: "user",
})
}
@@ -260,18 +265,18 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
}
functionResp := &FunctionResponse{
functionResp := &dto.GeminiFunctionResponse{
Name: name,
Response: contentMap,
}
*parts = append(*parts, GeminiPart{
*parts = append(*parts, dto.GeminiPart{
FunctionResponse: functionResp,
})
continue
}
var parts []GeminiPart
content := GeminiChatContent{
var parts []dto.GeminiPart
content := dto.GeminiChatContent{
Role: message.Role,
}
// isToolCall := false
@@ -285,8 +290,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
}
}
toolCall := GeminiPart{
FunctionCall: &FunctionCall{
toolCall := dto.GeminiPart{
FunctionCall: &dto.FunctionCall{
FunctionName: call.Function.Name,
Arguments: args,
},
@@ -303,7 +308,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if part.Text == "" {
continue
}
parts = append(parts, GeminiPart{
parts = append(parts, dto.GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
@@ -326,8 +331,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: fileData.MimeType, // 使用原始的 MimeType因为大小写可能对API有意义
Data: fileData.Base64Data,
},
@@ -337,8 +342,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: format,
Data: base64String,
},
@@ -352,8 +357,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: format,
Data: base64String,
},
@@ -366,8 +371,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: "audio/" + part.GetInputAudio().Format,
Data: base64String,
},
@@ -387,8 +392,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
}
if len(system_content) > 0 {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
geminiRequest.SystemInstructions = &dto.GeminiChatContent{
Parts: []dto.GeminiPart{
{
Text: strings.Join(system_content, "\n"),
},
@@ -631,7 +636,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
return data
}
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -653,7 +658,7 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
}
}
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: helper.GetResponseID(c),
Object: "chat.completion",
@@ -720,10 +725,9 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dt
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false
hasImage := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true
@@ -732,7 +736,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
choice := dto.ChatCompletionsStreamResponseChoice{
Index: int(candidate.Index),
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
//Role: "assistant",
},
}
var texts []string
@@ -754,7 +758,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
hasImage = true
}
} else if part.FunctionCall != nil {
isTools = true
@@ -791,28 +794,59 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Choices = choices
return &response, isStop, hasImage
return &response, isStop
}
func handleStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
streamData, err := common.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal stream response: %w", err)
}
err = openai.HandleStreamFormat(c, info, string(streamData), info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
return fmt.Errorf("failed to handle stream format: %w", err)
}
return nil
}
func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.ChatCompletionsStreamResponse) error {
streamData, err := common.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal stream response: %w", err)
}
openai.HandleFinalResponse(c, info, string(streamData), resp.Id, resp.Created, resp.Model, resp.GetSystemFingerprint(), resp.Usage, false)
return nil
}
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
// responseText := ""
id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
responseText := strings.Builder{}
var usage = &dto.Usage{}
var imageCount int
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
if hasImage {
imageCount++
for _, candidate := range geminiResponse.Candidates {
for _, part := range candidate.Content.Parts {
if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++
}
if part.Text != "" {
responseText.WriteString(part.Text)
}
}
}
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
@@ -829,18 +863,30 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
}
}
}
err = helper.ObjectData(c, response)
if info.SendResponseCount == 0 {
// send first response
err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
if err != nil {
common.LogError(c, err.Error())
}
}
err = handleStream(c, info, response)
if err != nil {
common.LogError(c, err.Error())
}
if isStop {
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
helper.ObjectData(c, response)
_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
}
return true
})
var response *dto.ChatCompletionsStreamResponse
if info.SendResponseCount == 0 {
// 空补全,报错不计费
// empty response, throw an error
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
}
if imageCount != 0 {
if usage.CompletionTokens == 0 {
@@ -851,14 +897,24 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
if info.ShouldIncludeUsage {
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
}
helper.Done(c)
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := handleFinalStream(c, info, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
// helper.Done(c)
//}
//resp.Body.Close()
return usage, nil
}
@@ -866,19 +922,19 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println(string(responseBody))
}
var geminiResponse GeminiChatResponse
var geminiResponse dto.GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 {
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
@@ -915,12 +971,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
var geminiResponse GeminiEmbeddingResponse
var geminiResponse dto.GeminiEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
@@ -950,9 +1006,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
jsonResponse, jsonErr := common.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse dto.GeminiImageResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
@@ -13,11 +12,18 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}

View File

@@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
var jimengResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// Check if the response indicates an error

View File

@@ -19,6 +19,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -17,10 +17,21 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
openaiAdaptor := openai.Adaptor{}
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
if err != nil {
return nil, err
}
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -37,6 +48,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return info.BaseUrl + "/v1/chat/completions", nil
}
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embed", nil
@@ -55,7 +69,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Ollama(*request)
return requestOpenAI2Ollama(request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -76,11 +90,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
usage, err = ollamaEmbeddingHandler(c, info, resp)
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
usage, err = ollamaEmbeddingHandler(c, info, resp)
default:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/gin-gonic/gin"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
@@ -92,15 +92,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if ollamaEmbeddingResponse.Error != "" {
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
@@ -121,7 +121,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
}
doResponseBody, err := common.Marshal(embeddingResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, doResponseBody)
return usage, nil

View File

@@ -34,6 +34,15 @@ type Adaptor struct {
ResponseFormat string
}
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
// 使用 service.GeminiToOpenAIRequest 转换请求格式
openaiRequest, err := service.GeminiToOpenAIRequest(request, info)
if err != nil {
return nil, err
}
return a.ConvertOpenAIRequest(c, info, openaiRequest)
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
//if !strings.Contains(request.Model, "claude") {
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
@@ -64,7 +73,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == relayconstant.RelayModeRealtime {

View File

@@ -2,6 +2,8 @@ package openai
import (
"encoding/json"
"errors"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
@@ -14,20 +16,23 @@ import (
)
// 辅助函数
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
case relaycommon.RelayFormatGemini:
return handleGeminiFormat(c, data, info)
}
return nil
}
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
@@ -41,6 +46,36 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
return nil
}
func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
common.LogError(c, "failed to unmarshal stream response: "+err.Error())
return err
}
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
// 如果返回 nil表示没有实际内容跳过发送
if geminiResponse == nil {
return nil
}
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.LogError(c, "failed to marshal gemini response: "+err.Error())
return err
}
// send gemini format response
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
} else {
return errors.New("streaming error: flusher not found")
}
return nil
}
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -158,7 +193,7 @@ func handleLastResponse(lastStreamData string, responseId *string, createAt *int
return nil
}
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
@@ -174,7 +209,7 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
@@ -183,7 +218,38 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
_ = helper.ClaudeData(c, *resp)
}
case relaycommon.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
// 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
// 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空finishReason 为 STOP 的响应
// 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
// 暂不知是否有程序会不兼容。
geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
// openai 流响应开头的空数据
if geminiResponse == nil {
return
}
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
common.SysError("error marshalling gemini response: " + err.Error())
return
}
// 发送最终的 Gemini 响应
c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
}

View File

@@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
defer common.CloseResponseBodyGracefully(resp)
@@ -123,30 +123,19 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
var forceFormat bool
var thinkToContent bool
if info.ChannelSetting.ForceFormat {
forceFormat = true
}
if info.ChannelSetting.ThinkingToContent {
thinkToContent = true
}
var (
lastStreamData string
)
var lastStreamData string
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
common.SysError("error handling stream format: " + err.Error())
}
}
lastStreamData = data
streamItems = append(streamItems, data)
if len(data) > 0 {
lastStreamData = data
streamItems = append(streamItems, data)
}
return true
})
@@ -154,16 +143,18 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil {
common.SysError("error handling last response: " + err.Error())
common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
}
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if shouldSendLastResp {
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
}
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.SysError("error processing tokens: " + err.Error())
common.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
@@ -176,8 +167,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
}
}
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
return usage, nil
}
@@ -188,14 +178,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
forceFormat := false
@@ -233,6 +223,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = claudeRespStr
case relaycommon.RelayFormatGemini:
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
geminiRespStr, err := common.Marshal(geminiResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = geminiRespStr
}
common.IOCopyBytesGracefully(c, resp, responseBody)
@@ -273,7 +270,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
@@ -557,13 +554,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
var usageResp dto.SimpleResponse
err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// 写入新的 response body

View File

@@ -22,14 +22,14 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
err = common.Unmarshal(responseBody, &responsesResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if responsesResponse.Error != nil {
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
if oaiError := responsesResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
}
// 写入新的 response body

View File

@@ -17,6 +17,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -17,6 +17,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -18,20 +18,24 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := openai.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
return nil, errors.New("not supported")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
adaptor := openai.Adaptor{}
return adaptor.ConvertImageRequest(c, info, request)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -47,7 +51,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -81,16 +85,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayMode {
case constant.RelayModeRerank:
usage, err = siliconflowRerankHandler(c, info, resp)
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions:
fallthrough
default:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -15,13 +15,13 @@ import (
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,

View File

@@ -0,0 +1,285 @@
package vidu
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"github.com/pkg/errors"
)
// ============================
// Request / Response structures
// ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct {
Model string `json:"model"`
Images []string `json:"images"`
Prompt string `json:"prompt,omitempty"`
Duration int `json:"duration,omitempty"`
Seed int `json:"seed,omitempty"`
Resolution string `json:"resolution,omitempty"`
MovementAmplitude string `json:"movement_amplitude,omitempty"`
Bgm bool `json:"bgm,omitempty"`
Payload string `json:"payload,omitempty"`
CallbackUrl string `json:"callback_url,omitempty"`
}
type responsePayload struct {
TaskId string `json:"task_id"`
State string `json:"state"`
Model string `json:"model"`
Images []string `json:"images"`
Prompt string `json:"prompt"`
Duration int `json:"duration"`
Seed int `json:"seed"`
Resolution string `json:"resolution"`
Bgm bool `json:"bgm"`
MovementAmplitude string `json:"movement_amplitude"`
Payload string `json:"payload"`
CreatedAt string `json:"created_at"`
}
type taskResultResponse struct {
State string `json:"state"`
ErrCode string `json:"err_code"`
Credits int `json:"credits"`
Payload string `json:"payload"`
Creations []creation `json:"creations"`
}
type creation struct {
ID string `json:"id"`
URL string `json:"url"`
CoverURL string `json:"cover_url"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {
var req SubmitReq
if err := c.ShouldBindJSON(&req); err != nil {
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
}
if req.Prompt == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
}
if req.Image != "" {
info.Action = constant.TaskActionGenerate
} else {
info.Action = constant.TaskActionTextGenerate
}
c.Set("task_request", req)
return nil
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("task_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, err
}
if len(body.Images) == 0 {
c.Set("action", constant.TaskActionTextGenerate)
}
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
var path string
switch info.Action {
case constant.TaskActionGenerate:
path = "/img2video"
default:
path = "/text2video"
}
return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil
}
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Token "+info.ApiKey)
return nil
}
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
if action := c.GetString("action"); action != "" {
info.Action = action
}
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
var vResp responsePayload
err = json.Unmarshal(responseBody, &vResp)
if err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrap(err, fmt.Sprintf("%s", responseBody)), "unmarshal_response_failed", http.StatusInternalServerError)
return
}
if vResp.State == "failed" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task failed"), "task_failed", http.StatusBadRequest)
return
}
c.JSON(http.StatusOK, vResp)
return vResp.TaskId, responseBody, nil
}
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
url := fmt.Sprintf("%s/ent/v2/tasks/%s/creations", baseUrl, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Token "+key)
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"viduq1", "vidu2.0", "vidu1.5"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "vidu"
}
// ============================
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
var images []string
if req.Image != "" {
images = []string{req.Image}
}
r := requestPayload{
Model: defaultString(req.Model, "viduq1"),
Images: images,
Prompt: req.Prompt,
Duration: defaultInt(req.Duration, 5),
Resolution: defaultString(req.Size, "1080p"),
MovementAmplitude: "auto",
Bgm: false,
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
}
func defaultString(value, defaultValue string) string {
if value == "" {
return defaultValue
}
return value
}
func defaultInt(value, defaultValue int) int {
if value == 0 {
return defaultValue
}
return value
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
taskInfo := &relaycommon.TaskInfo{}
var taskResp taskResultResponse
err := json.Unmarshal(respBody, &taskResp)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
state := taskResp.State
switch state {
case "created", "queueing":
taskInfo.Status = model.TaskStatusSubmitted
case "processing":
taskInfo.Status = model.TaskStatusInProgress
case "success":
taskInfo.Status = model.TaskStatusSuccess
if len(taskResp.Creations) > 0 {
taskInfo.Url = taskResp.Creations[0].URL
}
case "failed":
taskInfo.Status = model.TaskStatusFailure
if taskResp.ErrCode != "" {
taskInfo.Reason = taskResp.ErrCode
}
default:
return nil, fmt.Errorf("unknown task state: %s", state)
}
return taskInfo, nil
}

View File

@@ -25,6 +25,11 @@ type Adaptor struct {
Timestamp int64
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if tencentSb.Response.Error.Code != 0 {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -44,6 +44,11 @@ type Adaptor struct {
AccountCredentials Credentials
}
func (a *Adaptor) ConvertGeminiRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeminiChatRequest) (any, error) {
geminiAdaptor := gemini.Adaptor{}
return geminiAdaptor.ConvertGeminiRequest(c, info, request)
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
@@ -67,10 +72,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
} else {
a.RequestMode = RequestModeGemini
}
}
@@ -83,6 +88,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
@@ -100,6 +106,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else {
suffix = "generateContent"
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
suffix = "predict"
}
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
@@ -231,6 +242,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
} else {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return gemini.GeminiImageHandler(c, info, resp)
}
usage, err = gemini.GeminiChatHandler(c, info, resp)
}
case RequestModeLlama:

View File

@@ -36,7 +36,12 @@ var Cache = asynccache.NewAsyncCache(asynccache.Options{
})
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
var cacheKey string
if info.ChannelIsMultiKey {
cacheKey = fmt.Sprintf("access-token-%d-%d", info.ChannelId, info.ChannelMultiKeyIndex)
} else {
cacheKey = fmt.Sprintf("access-token-%d", info.ChannelId)
}
val, err := Cache.Get(cacheKey)
if err == nil {
return val.(string), nil

View File

@@ -23,6 +23,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -19,6 +19,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
//panic("implement me")

View File

@@ -17,6 +17,11 @@ type Adaptor struct {
request *dto.GeneralOpenAIRequest
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -16,6 +16,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if !zhipuResponse.Success {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -18,6 +18,11 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")

View File

@@ -40,7 +40,7 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateClaudeRequest(c)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if textRequest.Stream {
@@ -49,18 +49,18 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed)
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(textRequest.MaxTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
@@ -77,10 +77,9 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if textRequest.MaxTokens == 0 {
textRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
@@ -108,18 +107,41 @@ func ClaudeHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo.UpstreamModelName = textRequest.Model
}
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertClaudeRequest(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
requestBody = bytes.NewBuffer(jsonData)
}
jsonData, err := common.Marshal(convertedRequest)
if common.DebugEnabled {
println("requestBody: ", string(jsonData))
}
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
requestBody = bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response

View File

@@ -60,17 +60,19 @@ type ResponsesUsageInfo struct {
}
type RelayInfo struct {
ChannelType int
ChannelId int
TokenId int
TokenKey string
UserId int
UsingGroup string // 使用的分组
UserGroup string // 用户所在分组
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
ChannelType int
ChannelId int
ChannelIsMultiKey bool // 是否多密钥
ChannelMultiKeyIndex int // 多密钥索引
TokenId int
TokenKey string
UserId int
UsingGroup string // 使用的分组
UserGroup string // 用户所在分组
TokenUnlimited bool
StartTime time.Time
FirstResponseTime time.Time
isFirstResponse bool
//SendLastReasoningResponse bool
ApiType int
IsStream bool
@@ -88,6 +90,7 @@ type RelayInfo struct {
BaseUrl string
SupportStreamOptions bool
ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping
IsModelMapped bool
ClientWs *websocket.Conn
TargetWs *websocket.Conn
@@ -259,6 +262,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
},
ChannelIsMultiKey: common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey),
ChannelMultiKeyIndex: common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex),
}
if strings.HasPrefix(c.Request.URL.Path, "/pg") {
info.IsPlayground = true

View File

@@ -16,7 +16,7 @@ import (
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
@@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var xinRerankResponse xinference.XinRerankResponse
err = common.Unmarshal(responseBody, &xinRerankResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
@@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
} else {
err = common.Unmarshal(responseBody, &jinaResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}

View File

@@ -41,17 +41,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptToken := getEmbeddingPromptToken(*embeddingRequest)
@@ -59,7 +59,7 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -74,18 +74,17 @@ func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -2,9 +2,9 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
@@ -20,8 +20,8 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
request := &gemini.GeminiChatRequest{}
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error) {
request := &dto.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
@@ -44,7 +44,7 @@ func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
// }
}
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
func checkGeminiInputSensitive(textRequest *dto.GeminiChatRequest) ([]string, error) {
var inputTexts []string
for _, content := range textRequest.Contents {
for _, part := range content.Parts {
@@ -61,7 +61,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
return sensitiveWords, err
}
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
func getGeminiInputTokens(req *dto.GeminiChatRequest, info *relaycommon.RelayInfo) int {
// 计算输入 token 数量
var inputTexts []string
for _, content := range req.Contents {
@@ -78,9 +78,13 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
return inputTokens
}
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
func isNoThinkingRequest(req *dto.GeminiChatRequest) bool {
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
configBudget := req.GenerationConfig.ThinkingConfig.ThinkingBudget
if configBudget != nil && *configBudget == 0 {
// 如果思考预算为 0则认为是非思考请求
return true
}
}
return false
}
@@ -109,7 +113,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateGeminiRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
relayInfo := relaycommon.GenRelayInfoGemini(c)
@@ -121,14 +125,14 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
sensitiveWords, err := checkGeminiInputSensitive(req)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
// model mapped 模型映射
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
if value, exists := c.Get("prompt_tokens"); exists {
@@ -159,7 +163,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre consume quota
@@ -175,7 +179,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
@@ -194,19 +198,47 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
}
requestBody, err := json.Marshal(req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewReader(body)
} else {
// 使用 ConvertGeminiRequest 转换请求格式
convertedRequest, err := adaptor.ConvertGeminiRequest(c, relayInfo, req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
if common.DebugEnabled {
println("Gemini request body: %s", string(jsonData))
}
requestBody = bytes.NewReader(jsonData)
}
if common.DebugEnabled {
println("Gemini request body: %s", string(requestBody))
}
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
common.LogError(c, "Do gemini request failed: "+err.Error())
return types.NewError(err, types.ErrorCodeDoRequestFailed)
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -139,6 +139,24 @@ func GetLocalRealtimeID(c *gin.Context) string {
return fmt.Sprintf("evt_%s", logID)
}
func GenerateStartEmptyResponse(id string, createAt int64, model string, systemFingerprint *string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: systemFingerprint,
Choices: []dto.ChatCompletionsStreamResponseChoice{
{
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
Content: common.GetPointer(""),
},
},
},
}
}
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,

View File

@@ -54,7 +54,7 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
)
generalSettings := operation_setting.GetGeneralSetting()
pingEnabled := generalSettings.PingIntervalEnabled
pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
if pingInterval <= 0 {
pingInterval = DefaultPingInterval
@@ -234,6 +234,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-stopChan:
return
}
} else {
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
if common.DebugEnabled {
println("received [DONE], stopping scanner")
}
return
}
}

View File

@@ -16,6 +16,7 @@ import (
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/setting/model_setting"
"one-api/types"
"strings"
@@ -114,17 +115,17 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
imageRequest, err := getAndValidImageRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
var preConsumedQuota int
var quota int
@@ -172,37 +173,58 @@ func ImageHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return types.NewError(err, types.ErrorCodeQueryDataError)
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota-quota < 0 {
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota)
return types.NewError(fmt.Errorf("image pre-consumed quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota)), types.ErrorCodeInsufficientUserQuota, types.ErrOptionWithSkipRetry())
}
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
requestBody = convertedRequest.(io.Reader)
} else {
jsonData, err := json.Marshal(convertedRequest)
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(jsonData)
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if relayInfo.RelayMode == relayconstant.RelayModeImagesEdits {
requestBody = convertedRequest.(io.Reader)
} else {
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if common.DebugEnabled {
println(fmt.Sprintf("image request body: %s", requestBody))
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
if common.DebugEnabled {
println(fmt.Sprintf("image request body: %s", string(jsonData)))
}
requestBody = bytes.NewBuffer(jsonData)
}
}
statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
@@ -91,9 +90,8 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
// get & validate textRequest 获取并验证文本请求
textRequest, err := getAndValidateTextRequest(c, relayInfo)
if err != nil {
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if textRequest.WebSearchOptions != nil {
@@ -104,13 +102,13 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
words, err := checkRequestSensitive(textRequest, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
// 获取 promptTokens如果上下文中已经存在则直接使用
@@ -122,14 +120,14 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
promptTokens, err = getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed)
return types.NewError(err, types.ErrorCodeCountTokenFailed, types.ErrOptionWithSkipRetry())
}
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptTokens, int(math.Max(float64(textRequest.MaxTokens), float64(textRequest.MaxCompletionTokens))))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
@@ -166,25 +164,49 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest)
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
if common.DebugEnabled {
println("requestBody: ", string(body))
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIRequest(c, relayInfo, textRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := json.Marshal(convertedRequest)
if relayInfo.ChannelSetting.SystemPrompt != "" {
// 如果有系统提示,则将其添加到请求中
request := convertedRequest.(*dto.GeneralOpenAIRequest)
containSystemPrompt := false
for _, message := range request.Messages {
if message.Role == request.GetSystemRoleName() {
containSystemPrompt = true
break
}
}
if !containSystemPrompt {
// 如果没有系统提示,则添加系统提示
systemMessage := dto.Message{
Role: request.GetSystemRoleName(),
Content: relayInfo.ChannelSetting.SystemPrompt,
}
request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
}
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
@@ -196,7 +218,7 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
@@ -208,7 +230,6 @@ func TextHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
var httpResp *http.Response
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
}
@@ -281,13 +302,13 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *types.NewAPIError) {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError)
return 0, 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota <= 0 {
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
if userQuota-preConsumedQuota < 0 {
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(fmt.Errorf("pre-consume quota failed, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
relayInfo.UserQuota = userQuota
if userQuota > 100*preConsumedQuota {
@@ -311,11 +332,11 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if preConsumedQuota > 0 {
err := service.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden)
return 0, 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError)
return 0, 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
}
}
return preConsumedQuota, userQuota, nil
@@ -494,6 +515,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
} else {
if !ratio.IsZero() && quota == 0 {
quota = 1
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}

View File

@@ -27,6 +27,7 @@ import (
taskjimeng "one-api/relay/channel/task/jimeng"
"one-api/relay/channel/task/kling"
"one-api/relay/channel/task/suno"
taskVidu "one-api/relay/channel/task/vidu"
"one-api/relay/channel/tencent"
"one-api/relay/channel/vertex"
"one-api/relay/channel/volcengine"
@@ -122,6 +123,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
return &kling.TaskAdaptor{}
case constant.ChannelTypeJimeng:
return &taskjimeng.TaskAdaptor{}
case constant.ChannelTypeVidu:
return &taskVidu.TaskAdaptor{}
}
}
return nil

View File

@@ -3,12 +3,14 @@ package relay
import (
"bytes"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"one-api/types"
"github.com/gin-gonic/gin"
@@ -29,21 +31,21 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
err := common.UnmarshalBodyReusable(c, &rerankRequest)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
relayInfo := relaycommon.GenRelayInfoRerank(c, rerankRequest)
if rerankRequest.Query == "" {
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest)
return types.NewError(fmt.Errorf("query is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
if len(rerankRequest.Documents) == 0 {
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest)
return types.NewError(fmt.Errorf("documents is empty"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
promptToken := getRerankPromptToken(*rerankRequest)
@@ -51,7 +53,7 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -66,22 +68,46 @@ func RerankHelper(c *gin.Context, relayMode int) (newAPIError *types.NewAPIError
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
requestBody := bytes.NewBuffer(jsonData)
if common.DebugEnabled {
println(fmt.Sprintf("Rerank request body: %s", requestBody.String()))
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || relayInfo.ChannelSetting.PassThroughBodyEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := common.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
_ = common.Unmarshal(jsonData, &reqMap)
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = common.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
if common.DebugEnabled {
println(fmt.Sprintf("Rerank request body: %s", string(jsonData)))
}
requestBody = bytes.NewBuffer(jsonData)
}
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)

View File

@@ -51,7 +51,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
req, err := getAndValidateResponsesRequest(c)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidateResponsesRequest error: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
}
relayInfo := relaycommon.GenRelayInfoResponses(c, req)
@@ -60,13 +60,13 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
sensitiveWords, err := checkInputSensitive(req, relayInfo)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected)
return types.NewError(err, types.ErrorCodeSensitiveWordsDetected, types.ErrOptionWithSkipRetry())
}
}
err = helper.ModelMappedHelper(c, relayInfo, req)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
if value, exists := c.Get("prompt_tokens"); exists {
@@ -79,7 +79,7 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.MaxOutputTokens))
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre consume quota
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
@@ -93,38 +93,38 @@ func ResponsesHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}()
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
var requestBody io.Reader
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
body, err := common.GetRequestBody(c)
if err != nil {
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed)
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
}
requestBody = bytes.NewBuffer(body)
} else {
convertedRequest, err := adaptor.ConvertOpenAIResponsesRequest(c, relayInfo, *req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
// apply param override
if len(relayInfo.ParamOverride) > 0 {
reqMap := make(map[string]interface{})
err = json.Unmarshal(jsonData, &reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid)
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
for key, value := range relayInfo.ParamOverride {
reqMap[key] = value
}
jsonData, err = json.Marshal(reqMap)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
}

View File

@@ -24,12 +24,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
err := helper.ModelMappedHelper(c, relayInfo, nil)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
@@ -46,7 +46,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIErr
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
//var requestBody io.Reader

View File

@@ -44,6 +44,7 @@ func SetApiRouter(router *gin.Engine) {
{
userRoute.POST("/register", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Register)
userRoute.POST("/login", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.Login)
userRoute.POST("/login/2fa", middleware.CriticalRateLimit(), controller.Verify2FALogin)
//userRoute.POST("/tokenlog", middleware.CriticalRateLimit(), controller.TokenLog)
userRoute.GET("/logout", controller.Logout)
userRoute.GET("/epay/notify", controller.EpayNotify)
@@ -66,6 +67,13 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.POST("/stripe/amount", controller.RequestStripeAmount)
selfRoute.POST("/aff_transfer", controller.TransferAffQuota)
selfRoute.PUT("/setting", controller.UpdateUserSetting)
// 2FA routes
selfRoute.GET("/2fa/status", controller.Get2FAStatus)
selfRoute.POST("/2fa/setup", controller.Setup2FA)
selfRoute.POST("/2fa/enable", controller.Enable2FA)
selfRoute.POST("/2fa/disable", controller.Disable2FA)
selfRoute.POST("/2fa/backup_codes", controller.RegenerateBackupCodes)
}
adminRoute := userRoute.Group("/")
@@ -78,6 +86,10 @@ func SetApiRouter(router *gin.Engine) {
adminRoute.POST("/manage", controller.ManageUser)
adminRoute.PUT("/", controller.UpdateUser)
adminRoute.DELETE("/:id", controller.DeleteUser)
// Admin 2FA routes
adminRoute.GET("/2fa/stats", controller.Admin2FAStats)
adminRoute.DELETE("/:id/2fa", controller.AdminDisable2FA)
}
}
optionRoute := apiRouter.Group("/option")
@@ -120,6 +132,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
channelRoute.GET("/tag/models", controller.GetTagModels)
channelRoute.POST("/copy/:id", controller.CopyChannel)
channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())

View File

@@ -45,7 +45,7 @@ func ShouldDisableChannel(channelType int, err *types.NewAPIError) bool {
if types.IsChannelError(err) {
return true
}
if types.IsLocalError(err) {
if types.IsSkipRetryError(err) {
return false
}
if err.StatusCode == http.StatusUnauthorized {

View File

@@ -188,28 +188,6 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
return &openAIRequest, nil
}
func OpenAIErrorToClaudeError(openAIError *dto.OpenAIErrorWithStatusCode) *dto.ClaudeErrorWithStatusCode {
claudeError := dto.ClaudeError{
Type: "new_api_error",
Message: openAIError.Error.Message,
}
return &dto.ClaudeErrorWithStatusCode{
Error: claudeError,
StatusCode: openAIError.StatusCode,
}
}
func ClaudeErrorToOpenAIError(claudeError *dto.ClaudeErrorWithStatusCode) *dto.OpenAIErrorWithStatusCode {
openAIError := dto.OpenAIError{
Message: claudeError.Error.Message,
Type: "new_api_error",
}
return &dto.OpenAIErrorWithStatusCode{
Error: openAIError,
StatusCode: claudeError.StatusCode,
}
}
func generateStopBlock(index int) *dto.ClaudeResponse {
return &dto.ClaudeResponse{
Type: "content_block_stop",
@@ -251,22 +229,54 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
} else {
//resp := &dto.ClaudeResponse{
// Type: "content_block_start",
// ContentBlock: &dto.ClaudeMediaMessage{
// Type: "text",
// Text: common.GetPointer[string](""),
// },
//}
//resp.SetIndex(0)
//claudeResponses = append(claudeResponses, resp)
}
// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
})
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
},
})
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
}
return claudeResponses
}
if len(openAIResponse.Choices) == 0 {
// no choices
// TODO: handle this case
// 可能为非标准的 OpenAI 响应,判断是否已经完成
if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
oaiUsage := info.ClaudeConvertInfo.Usage
if oaiUsage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
InputTokens: oaiUsage.PromptTokens,
OutputTokens: oaiUsage.CompletionTokens,
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
},
Delta: &dto.ClaudeMediaMessage{
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
},
})
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_stop",
})
}
return claudeResponses
} else {
chosenChoice := openAIResponse.Choices[0]
@@ -438,3 +448,353 @@ func toJSONString(v interface{}) string {
}
return string(b)
}
func GeminiToOpenAIRequest(geminiRequest *dto.GeminiChatRequest, info *relaycommon.RelayInfo) (*dto.GeneralOpenAIRequest, error) {
openaiRequest := &dto.GeneralOpenAIRequest{
Model: info.UpstreamModelName,
Stream: info.IsStream,
}
// 转换 messages
var messages []dto.Message
for _, content := range geminiRequest.Contents {
message := dto.Message{
Role: convertGeminiRoleToOpenAI(content.Role),
}
// 处理 parts
var mediaContents []dto.MediaContent
var toolCalls []dto.ToolCallRequest
for _, part := range content.Parts {
if part.Text != "" {
mediaContent := dto.MediaContent{
Type: "text",
Text: part.Text,
}
mediaContents = append(mediaContents, mediaContent)
} else if part.InlineData != nil {
mediaContent := dto.MediaContent{
Type: "image_url",
ImageUrl: &dto.MessageImageUrl{
Url: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MimeType, part.InlineData.Data),
Detail: "auto",
MimeType: part.InlineData.MimeType,
},
}
mediaContents = append(mediaContents, mediaContent)
} else if part.FileData != nil {
mediaContent := dto.MediaContent{
Type: "image_url",
ImageUrl: &dto.MessageImageUrl{
Url: part.FileData.FileUri,
Detail: "auto",
MimeType: part.FileData.MimeType,
},
}
mediaContents = append(mediaContents, mediaContent)
} else if part.FunctionCall != nil {
// 处理 Gemini 的工具调用
toolCall := dto.ToolCallRequest{
ID: fmt.Sprintf("call_%d", len(toolCalls)+1), // 生成唯一ID
Type: "function",
Function: dto.FunctionRequest{
Name: part.FunctionCall.FunctionName,
Arguments: toJSONString(part.FunctionCall.Arguments),
},
}
toolCalls = append(toolCalls, toolCall)
} else if part.FunctionResponse != nil {
// 处理 Gemini 的工具响应,创建单独的 tool 消息
toolMessage := dto.Message{
Role: "tool",
ToolCallId: fmt.Sprintf("call_%d", len(toolCalls)), // 使用对应的调用ID
}
toolMessage.SetStringContent(toJSONString(part.FunctionResponse.Response))
messages = append(messages, toolMessage)
}
}
// 设置消息内容
if len(toolCalls) > 0 {
// 如果有工具调用,设置工具调用
message.SetToolCalls(toolCalls)
} else if len(mediaContents) == 1 && mediaContents[0].Type == "text" {
// 如果只有一个文本内容,直接设置字符串
message.Content = mediaContents[0].Text
} else if len(mediaContents) > 0 {
// 如果有多个内容或包含媒体,设置为数组
message.SetMediaContent(mediaContents)
}
// 只有当消息有内容或工具调用时才添加
if len(message.ParseContent()) > 0 || len(message.ToolCalls) > 0 {
messages = append(messages, message)
}
}
openaiRequest.Messages = messages
if geminiRequest.GenerationConfig.Temperature != nil {
openaiRequest.Temperature = geminiRequest.GenerationConfig.Temperature
}
if geminiRequest.GenerationConfig.TopP > 0 {
openaiRequest.TopP = geminiRequest.GenerationConfig.TopP
}
if geminiRequest.GenerationConfig.TopK > 0 {
openaiRequest.TopK = int(geminiRequest.GenerationConfig.TopK)
}
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
openaiRequest.MaxTokens = geminiRequest.GenerationConfig.MaxOutputTokens
}
// gemini stop sequences 最多 5 个openai stop 最多 4 个
if len(geminiRequest.GenerationConfig.StopSequences) > 0 {
openaiRequest.Stop = geminiRequest.GenerationConfig.StopSequences[:4]
}
if geminiRequest.GenerationConfig.CandidateCount > 0 {
openaiRequest.N = geminiRequest.GenerationConfig.CandidateCount
}
// 转换工具调用
if len(geminiRequest.Tools) > 0 {
var tools []dto.ToolCallRequest
for _, tool := range geminiRequest.Tools {
if tool.FunctionDeclarations != nil {
// 将 Gemini 的 FunctionDeclarations 转换为 OpenAI 的 ToolCallRequest
functionDeclarations, ok := tool.FunctionDeclarations.([]dto.FunctionRequest)
if ok {
for _, function := range functionDeclarations {
openAITool := dto.ToolCallRequest{
Type: "function",
Function: dto.FunctionRequest{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
}
tools = append(tools, openAITool)
}
}
}
}
if len(tools) > 0 {
openaiRequest.Tools = tools
}
}
// gemini system instructions
if geminiRequest.SystemInstructions != nil {
// 将系统指令作为第一条消息插入
systemMessage := dto.Message{
Role: "system",
Content: extractTextFromGeminiParts(geminiRequest.SystemInstructions.Parts),
}
openaiRequest.Messages = append([]dto.Message{systemMessage}, openaiRequest.Messages...)
}
return openaiRequest, nil
}
func convertGeminiRoleToOpenAI(geminiRole string) string {
switch geminiRole {
case "user":
return "user"
case "model":
return "assistant"
case "function":
return "function"
default:
return "user"
}
}
func extractTextFromGeminiParts(parts []dto.GeminiPart) string {
var texts []string
for _, part := range parts {
if part.Text != "" {
texts = append(texts, part.Text)
}
}
return strings.Join(texts, "\n")
}
// ResponseOpenAI2Gemini 将 OpenAI 响应转换为 Gemini 格式
func ResponseOpenAI2Gemini(openAIResponse *dto.OpenAITextResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
geminiResponse := &dto.GeminiChatResponse{
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
PromptFeedback: dto.GeminiChatPromptFeedback{
SafetyRatings: []dto.GeminiChatSafetyRating{},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: openAIResponse.PromptTokens,
CandidatesTokenCount: openAIResponse.CompletionTokens,
TotalTokenCount: openAIResponse.PromptTokens + openAIResponse.CompletionTokens,
},
}
for _, choice := range openAIResponse.Choices {
candidate := dto.GeminiChatCandidate{
Index: int64(choice.Index),
SafetyRatings: []dto.GeminiChatSafetyRating{},
}
// 设置结束原因
var finishReason string
switch choice.FinishReason {
case "stop":
finishReason = "STOP"
case "length":
finishReason = "MAX_TOKENS"
case "content_filter":
finishReason = "SAFETY"
case "tool_calls":
finishReason = "STOP"
default:
finishReason = "STOP"
}
candidate.FinishReason = &finishReason
// 转换消息内容
content := dto.GeminiChatContent{
Role: "model",
Parts: make([]dto.GeminiPart, 0),
}
// 处理工具调用
toolCalls := choice.Message.ParseToolCalls()
if len(toolCalls) > 0 {
for _, toolCall := range toolCalls {
// 解析参数
var args map[string]interface{}
if toolCall.Function.Arguments != "" {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
}
} else {
args = make(map[string]interface{})
}
part := dto.GeminiPart{
FunctionCall: &dto.FunctionCall{
FunctionName: toolCall.Function.Name,
Arguments: args,
},
}
content.Parts = append(content.Parts, part)
}
} else {
// 处理文本内容
textContent := choice.Message.StringContent()
if textContent != "" {
part := dto.GeminiPart{
Text: textContent,
}
content.Parts = append(content.Parts, part)
}
}
candidate.Content = content
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
}
return geminiResponse
}
// StreamResponseOpenAI2Gemini 将 OpenAI 流式响应转换为 Gemini 格式
func StreamResponseOpenAI2Gemini(openAIResponse *dto.ChatCompletionsStreamResponse, info *relaycommon.RelayInfo) *dto.GeminiChatResponse {
// 检查是否有实际内容或结束标志
hasContent := false
hasFinishReason := false
for _, choice := range openAIResponse.Choices {
if len(choice.Delta.GetContentString()) > 0 || (choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0) {
hasContent = true
}
if choice.FinishReason != nil {
hasFinishReason = true
}
}
// 如果没有实际内容且没有结束标志,跳过。主要针对 openai 流响应开头的空数据
if !hasContent && !hasFinishReason {
return nil
}
geminiResponse := &dto.GeminiChatResponse{
Candidates: make([]dto.GeminiChatCandidate, 0, len(openAIResponse.Choices)),
PromptFeedback: dto.GeminiChatPromptFeedback{
SafetyRatings: []dto.GeminiChatSafetyRating{},
},
UsageMetadata: dto.GeminiUsageMetadata{
PromptTokenCount: info.PromptTokens,
CandidatesTokenCount: 0, // 流式响应中可能没有完整的 usage 信息
TotalTokenCount: info.PromptTokens,
},
}
for _, choice := range openAIResponse.Choices {
candidate := dto.GeminiChatCandidate{
Index: int64(choice.Index),
SafetyRatings: []dto.GeminiChatSafetyRating{},
}
// 设置结束原因
if choice.FinishReason != nil {
var finishReason string
switch *choice.FinishReason {
case "stop":
finishReason = "STOP"
case "length":
finishReason = "MAX_TOKENS"
case "content_filter":
finishReason = "SAFETY"
case "tool_calls":
finishReason = "STOP"
default:
finishReason = "STOP"
}
candidate.FinishReason = &finishReason
}
// 转换消息内容
content := dto.GeminiChatContent{
Role: "model",
Parts: make([]dto.GeminiPart, 0),
}
// 处理工具调用
if choice.Delta.ToolCalls != nil {
for _, toolCall := range choice.Delta.ToolCalls {
// 解析参数
var args map[string]interface{}
if toolCall.Function.Arguments != "" {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil {
args = map[string]interface{}{"arguments": toolCall.Function.Arguments}
}
} else {
args = make(map[string]interface{})
}
part := dto.GeminiPart{
FunctionCall: &dto.FunctionCall{
FunctionName: toolCall.Function.Name,
Arguments: args,
},
}
content.Parts = append(content.Parts, part)
}
} else {
// 处理文本内容
textContent := choice.Delta.GetContentString()
if textContent != "" {
part := dto.GeminiPart{
Text: textContent,
}
content.Parts = append(content.Parts, part)
}
}
candidate.Content = content
geminiResponse.Candidates = append(geminiResponse.Candidates, candidate)
}
return geminiResponse
}

View File

@@ -1,7 +1,6 @@
package service
import (
"encoding/json"
"errors"
"fmt"
"io"
@@ -63,7 +62,7 @@ func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeError
text = "请求上游地址失败"
}
}
claudeError := dto.ClaudeError{
claudeError := types.ClaudeError{
Message: text,
Type: "new_api_error",
}
@@ -80,10 +79,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
}
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
newApiErr = &types.NewAPIError{
StatusCode: resp.StatusCode,
ErrorType: types.ErrorTypeOpenAIError,
}
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -105,8 +101,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
// General format error (OpenAI, Anthropic, Gemini, etc.)
newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
} else {
newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
newApiErr.ErrorType = types.ErrorTypeOpenAIError
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
}
return
}
@@ -116,7 +111,7 @@ func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string)
return
}
statusCodeMapping := make(map[string]string)
err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
if err != nil {
return
}

View File

@@ -3,6 +3,7 @@ package setting
import (
"encoding/json"
"fmt"
"math"
"one-api/common"
"sync"
)
@@ -58,6 +59,9 @@ func CheckModelRequestRateLimitGroup(jsonStr string) error {
if limits[0] < 0 || limits[1] < 1 {
return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1])
}
if limits[0] > math.MaxInt32 || limits[1] > math.MaxInt32 {
return fmt.Errorf("group %s [%d, %d] has max rate limits value 2147483647", group, limits[0], limits[1])
}
}
return nil

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/http"
"one-api/common"
"strings"
)
@@ -15,8 +16,8 @@ type OpenAIError struct {
}
type ClaudeError struct {
Message string `json:"message,omitempty"`
Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
}
type ErrorType string
@@ -28,6 +29,7 @@ const (
ErrorTypeMidjourneyError ErrorType = "midjourney_error"
ErrorTypeGeminiError ErrorType = "gemini_error"
ErrorTypeRerankError ErrorType = "rerank_error"
ErrorTypeUpstreamError ErrorType = "upstream_error"
)
type ErrorCode string
@@ -62,6 +64,7 @@ const (
ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
ErrorCodeBadResponse ErrorCode = "bad_response"
ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
ErrorCodeEmptyResponse ErrorCode = "empty_response"
// sql error
ErrorCodeQueryDataError ErrorCode = "query_data_error"
@@ -73,11 +76,13 @@ const (
)
type NewAPIError struct {
Err error
RelayError any
ErrorType ErrorType
errorCode ErrorCode
StatusCode int
Err error
RelayError any
skipRetry bool
recordErrorLog *bool
errorType ErrorType
errorCode ErrorCode
StatusCode int
}
func (e *NewAPIError) GetErrorCode() ErrorCode {
@@ -87,6 +92,13 @@ func (e *NewAPIError) GetErrorCode() ErrorCode {
return e.errorCode
}
func (e *NewAPIError) GetErrorType() ErrorType {
if e == nil {
return ""
}
return e.errorType
}
func (e *NewAPIError) Error() string {
if e == nil {
return ""
@@ -98,19 +110,30 @@ func (e *NewAPIError) Error() string {
return e.Err.Error()
}
func (e *NewAPIError) MaskSensitiveError() string {
if e == nil {
return ""
}
if e.Err == nil {
return string(e.errorCode)
}
return common.MaskSensitiveInfo(e.Err.Error())
}
func (e *NewAPIError) SetMessage(message string) {
e.Err = errors.New(message)
}
func (e *NewAPIError) ToOpenAIError() OpenAIError {
switch e.ErrorType {
var result OpenAIError
switch e.errorType {
case ErrorTypeOpenAIError:
if openAIError, ok := e.RelayError.(OpenAIError); ok {
return openAIError
result = openAIError
}
case ErrorTypeClaudeError:
if claudeError, ok := e.RelayError.(ClaudeError); ok {
return OpenAIError{
result = OpenAIError{
Message: e.Error(),
Type: claudeError.Type,
Param: "",
@@ -118,85 +141,122 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
}
}
}
return OpenAIError{
result = OpenAIError{
Message: e.Error(),
Type: string(e.ErrorType),
Type: string(e.errorType),
Param: "",
Code: e.errorCode,
}
result.Message = common.MaskSensitiveInfo(result.Message)
return result
}
func (e *NewAPIError) ToClaudeError() ClaudeError {
switch e.ErrorType {
var result ClaudeError
switch e.errorType {
case ErrorTypeOpenAIError:
openAIError := e.RelayError.(OpenAIError)
return ClaudeError{
result = ClaudeError{
Message: e.Error(),
Type: fmt.Sprintf("%v", openAIError.Code),
}
case ErrorTypeClaudeError:
return e.RelayError.(ClaudeError)
result = e.RelayError.(ClaudeError)
default:
return ClaudeError{
result = ClaudeError{
Message: e.Error(),
Type: string(e.ErrorType),
Type: string(e.errorType),
}
}
result.Message = common.MaskSensitiveInfo(result.Message)
return result
}
func NewError(err error, errorCode ErrorCode) *NewAPIError {
return &NewAPIError{
type NewAPIErrorOptions func(*NewAPIError)
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
e := &NewAPIError{
Err: err,
RelayError: nil,
ErrorType: ErrorTypeNewAPIError,
errorType: ErrorTypeNewAPIError,
StatusCode: http.StatusInternalServerError,
errorCode: errorCode,
}
for _, op := range ops {
op(e)
}
return e
}
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
openaiError := OpenAIError{
Message: err.Error(),
Type: string(errorCode),
}
return WithOpenAIError(openaiError, statusCode)
return WithOpenAIError(openaiError, statusCode, ops...)
}
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
return &NewAPIError{
func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
openaiError := OpenAIError{
Type: string(errorCode),
}
return WithOpenAIError(openaiError, statusCode, ops...)
}
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
e := &NewAPIError{
Err: err,
RelayError: OpenAIError{
Message: err.Error(),
Type: string(errorCode),
},
ErrorType: ErrorTypeNewAPIError,
errorType: ErrorTypeNewAPIError,
StatusCode: statusCode,
errorCode: errorCode,
}
for _, op := range ops {
op(e)
}
return e
}
func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
code, ok := openAIError.Code.(string)
if !ok {
code = fmt.Sprintf("%v", openAIError.Code)
}
return &NewAPIError{
if openAIError.Type == "" {
openAIError.Type = "upstream_error"
}
e := &NewAPIError{
RelayError: openAIError,
ErrorType: ErrorTypeOpenAIError,
errorType: ErrorTypeOpenAIError,
StatusCode: statusCode,
Err: errors.New(openAIError.Message),
errorCode: ErrorCode(code),
}
for _, op := range ops {
op(e)
}
return e
}
func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError {
return &NewAPIError{
func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
if claudeError.Type == "" {
claudeError.Type = "upstream_error"
}
e := &NewAPIError{
RelayError: claudeError,
ErrorType: ErrorTypeClaudeError,
errorType: ErrorTypeClaudeError,
StatusCode: statusCode,
Err: errors.New(claudeError.Message),
errorCode: ErrorCode(claudeError.Type),
}
for _, op := range ops {
op(e)
}
return e
}
func IsChannelError(err *NewAPIError) bool {
@@ -206,10 +266,33 @@ func IsChannelError(err *NewAPIError) bool {
return strings.HasPrefix(string(err.errorCode), "channel:")
}
func IsLocalError(err *NewAPIError) bool {
func IsSkipRetryError(err *NewAPIError) bool {
if err == nil {
return false
}
return err.ErrorType == ErrorTypeNewAPIError
return err.skipRetry
}
func ErrOptionWithSkipRetry() NewAPIErrorOptions {
return func(e *NewAPIError) {
e.skipRetry = true
}
}
func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
return func(e *NewAPIError) {
e.recordErrorLog = common.GetPointer(false)
}
}
func IsRecordErrorLog(e *NewAPIError) bool {
if e == nil {
return false
}
if e.recordErrorLog == nil {
// default to true if not set
return true
}
return *e.recordErrorLog
}

View File

@@ -21,6 +21,7 @@
"lucide-react": "^0.511.0",
"marked": "^4.1.1",
"mermaid": "^11.6.0",
"qrcode.react": "^4.2.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
@@ -1492,6 +1493,8 @@
"punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="],
"qrcode.react": ["qrcode.react@4.2.0", "", { "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA=="],
"quansync": ["quansync@0.2.10", "", {}, "sha512-t41VRkMYbkHyCYmOvx/6URnN80H7k4X0lLdBMGsz+maAwrJQYB1djpV6vHrQIBE0WBSGqhtEHrK9U3DWWH8v7A=="],
"query-string": ["query-string@9.2.0", "", { "dependencies": { "decode-uri-component": "^0.4.1", "filter-obj": "^5.1.0", "split-on-first": "^3.0.0" } }, "sha512-YIRhrHujoQxhexwRLxfy3VSjOXmvZRd2nyw1PwL1UUqZ/ys1dEZd1+NSgXkne2l/4X/7OXkigEAuhTX0g/ivJQ=="],
@@ -1502,7 +1505,7 @@
"rc-checkbox": ["rc-checkbox@3.5.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "^2.3.2", "rc-util": "^5.25.2" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-aOAQc3E98HteIIsSqm6Xk2FPKIER6+5vyEFMZfo73TqM+VVAIqOkHoPjgKLqSNtVLWScoaM7vY2ZrGEheI79yg=="],
"rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="],
"rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="],
"rc-dialog": ["rc-dialog@9.6.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "@rc-component/portal": "^1.0.0-8", "classnames": "^2.2.6", "rc-motion": "^2.3.0", "rc-util": "^5.21.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-ApoVi9Z8PaCQg6FsUzS8yvBEQy0ZL2PkuvAgrmohPkN3okps5WZ5WQWPc1RNuiOKaAYv8B97ACdsFU5LizzCqg=="],
@@ -1946,8 +1949,6 @@
"@lobehub/ui/lucide-react": ["lucide-react@0.484.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-oZy8coK9kZzvqhSgfbGkPtTgyjpBvs3ukLgDPv14dSOZtBtboryWF5o8i3qen7QbGg7JhiJBz5mK1p8YoMZTLQ=="],
"@lobehub/ui/rc-collapse": ["rc-collapse@4.0.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-SwoOByE39/3oIokDs/BnkqI+ltwirZbP8HZdq1/3SkPSBi7xDdvWHTp7cpNI9ullozkR6mwTWQi6/E/9huQVrA=="],
"@radix-ui/react-dismissable-layer/@radix-ui/react-compose-refs": ["@radix-ui/react-compose-refs@1.0.0", "", { "dependencies": { "@babel/runtime": "^7.13.10" }, "peerDependencies": { "react": "^16.8 || ^17.0 || ^18.0" } }, "sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA=="],
"@radix-ui/react-popper/@floating-ui/react-dom": ["@floating-ui/react-dom@0.7.2", "", { "dependencies": { "@floating-ui/dom": "^0.5.3", "use-isomorphic-layout-effect": "^1.1.1" }, "peerDependencies": { "react": ">=16.8.0", "react-dom": ">=16.8.0" } }, "sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg=="],
@@ -1964,6 +1965,8 @@
"@visactor/vrender-kits/roughjs": ["roughjs@4.5.2", "", { "dependencies": { "path-data-parser": "^0.1.0", "points-on-curve": "^0.2.0", "points-on-path": "^0.2.1" } }, "sha512-2xSlLDKdsWyFxrveYWk9YQ/Y9UfK38EAMRNkYkMqYBJvPX8abCa9PN0x3w02H8Oa6/0bcZICJU+U95VumPqseg=="],
"antd/rc-collapse": ["rc-collapse@3.9.0", "", { "dependencies": { "@babel/runtime": "^7.10.1", "classnames": "2.x", "rc-motion": "^2.3.4", "rc-util": "^5.27.0" }, "peerDependencies": { "react": ">=16.9.0", "react-dom": ">=16.9.0" } }, "sha512-swDdz4QZ4dFTo4RAUMLL50qP0EY62N2kvmk2We5xYdRwcRn8WcYtuetCJpwpaCbUfUt5+huLpVxhvmnK+PHrkA=="],
"antd/scroll-into-view-if-needed": ["scroll-into-view-if-needed@3.1.0", "", { "dependencies": { "compute-scroll-into-view": "^3.0.2" } }, "sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ=="],
"chokidar/glob-parent": ["glob-parent@5.1.2", "", { "dependencies": { "is-glob": "^4.0.1" } }, "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow=="],

View File

@@ -21,6 +21,7 @@
"lucide-react": "^0.511.0",
"marked": "^4.1.1",
"mermaid": "^11.6.0",
"qrcode.react": "^4.2.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",

View File

@@ -50,6 +50,7 @@ import { IconGithubLogo, IconMail, IconLock } from '@douyinfe/semi-icons';
import OIDCIcon from '../common/logo/OIDCIcon.js';
import WeChatIcon from '../common/logo/WeChatIcon.js';
import LinuxDoIcon from '../common/logo/LinuxDoIcon.js';
import TwoFAVerification from './TwoFAVerification.js';
import { useTranslation } from 'react-i18next';
const LoginForm = () => {
@@ -78,6 +79,7 @@ const LoginForm = () => {
const [resetPasswordLoading, setResetPasswordLoading] = useState(false);
const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false);
const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false);
const [showTwoFA, setShowTwoFA] = useState(false);
const logo = getLogo();
const systemName = getSystemName();
@@ -162,6 +164,13 @@ const LoginForm = () => {
);
const { success, message, data } = res.data;
if (success) {
// 检查是否需要2FA验证
if (data && data.require_2fa) {
setShowTwoFA(true);
setLoginLoading(false);
return;
}
userDispatch({ type: 'login', payload: data });
setUserData(data);
updateAPI();
@@ -280,6 +289,21 @@ const LoginForm = () => {
setOtherLoginOptionsLoading(false);
};
// 2FA验证成功处理
const handle2FASuccess = (data) => {
userDispatch({ type: 'login', payload: data });
setUserData(data);
updateAPI();
showSuccess('登录成功!');
navigate('/console');
};
// 返回登录页面
const handleBackToLogin = () => {
setShowTwoFA(false);
setInputs({ username: '', password: '', wechat_verification_code: '' });
};
const renderOAuthOptions = () => {
return (
<div className="flex flex-col items-center">
@@ -537,6 +561,35 @@ const LoginForm = () => {
);
};
// 2FA验证弹窗
const render2FAModal = () => {
return (
<Modal
title={
<div className="flex items-center">
<div className="w-8 h-8 rounded-full bg-green-100 dark:bg-green-900 flex items-center justify-center mr-3">
<svg className="w-4 h-4 text-green-600 dark:text-green-400" fill="currentColor" viewBox="0 0 20 20">
<path fillRule="evenodd" d="M6 8a2 2 0 11-4 0 2 2 0 014 0zM8 7a1 1 0 100 2h8a1 1 0 100-2H8zM6 14a2 2 0 11-4 0 2 2 0 014 0zM8 13a1 1 0 100 2h8a1 1 0 100-2H8z" clipRule="evenodd" />
</svg>
</div>
两步验证
</div>
}
visible={showTwoFA}
onCancel={handleBackToLogin}
footer={null}
width={450}
centered
>
<TwoFAVerification
onSuccess={handle2FASuccess}
onBack={handleBackToLogin}
isModal={true}
/>
</Modal>
);
};
return (
<div className="relative overflow-hidden bg-gray-100 flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
{/* 背景模糊晕染球 */}
@@ -547,6 +600,7 @@ const LoginForm = () => {
? renderEmailLoginForm()
: renderOAuthOptions()}
{renderWeChatLoginModal()}
{render2FAModal()}
{turnstileEnabled && (
<div className="flex justify-center mt-6">

View File

@@ -0,0 +1,230 @@
/*
Copyright (C) 2025 QuantumNous
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import { API, showError, showSuccess } from '../../helpers';
import { Button, Card, Divider, Form, Input, Typography } from '@douyinfe/semi-ui';
import React, { useState } from 'react';
const { Title, Text, Paragraph } = Typography;
const TwoFAVerification = ({ onSuccess, onBack, isModal = false }) => {
const [loading, setLoading] = useState(false);
const [useBackupCode, setUseBackupCode] = useState(false);
const [verificationCode, setVerificationCode] = useState('');
const handleSubmit = async () => {
if (!verificationCode) {
showError('请输入验证码');
return;
}
// Validate code format
if (useBackupCode && verificationCode.length !== 8) {
showError('备用码必须是8位');
return;
} else if (!useBackupCode && !/^\d{6}$/.test(verificationCode)) {
showError('验证码必须是6位数字');
return;
}
setLoading(true);
try {
const res = await API.post('/api/user/login/2fa', {
code: verificationCode
});
if (res.data.success) {
showSuccess('登录成功');
// 保存用户信息到本地存储
localStorage.setItem('user', JSON.stringify(res.data.data));
if (onSuccess) {
onSuccess(res.data.data);
}
} else {
showError(res.data.message);
}
} catch (error) {
showError('验证失败,请重试');
} finally {
setLoading(false);
}
};
const handleKeyPress = (e) => {
if (e.key === 'Enter') {
handleSubmit();
}
};
if (isModal) {
return (
<div className="space-y-4">
<Paragraph className="text-gray-600 dark:text-gray-300">
请输入认证器应用显示的验证码完成登录
</Paragraph>
<Form onSubmit={handleSubmit}>
<Form.Input
field="code"
label={useBackupCode ? "备用码" : "验证码"}
placeholder={useBackupCode ? "请输入8位备用码" : "请输入6位验证码"}
value={verificationCode}
onChange={setVerificationCode}
onKeyPress={handleKeyPress}
size="large"
style={{ marginBottom: 16 }}
autoFocus
/>
<Button
htmlType="submit"
type="primary"
loading={loading}
block
size="large"
style={{ marginBottom: 16 }}
>
验证并登录
</Button>
</Form>
<Divider />
<div style={{ textAlign: 'center' }}>
<Button
theme="borderless"
type="tertiary"
onClick={() => {
setUseBackupCode(!useBackupCode);
setVerificationCode('');
}}
style={{ marginRight: 16, color: '#1890ff', padding: 0 }}
>
{useBackupCode ? '使用认证器验证码' : '使用备用码'}
</Button>
{onBack && (
<Button
theme="borderless"
type="tertiary"
onClick={onBack}
style={{ color: '#1890ff', padding: 0 }}
>
返回登录
</Button>
)}
</div>
<div className="bg-gray-50 dark:bg-gray-800 rounded-lg p-3">
<Text size="small" type="secondary">
<strong>提示</strong>
<br />
验证码每30秒更新一次
<br />
如果无法获取验证码请使用备用码
<br />
每个备用码只能使用一次
</Text>
</div>
</div>
);
}
return (
<div style={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
minHeight: '60vh'
}}>
<Card style={{ width: 400, padding: 24 }}>
<div style={{ textAlign: 'center', marginBottom: 24 }}>
<Title heading={3}>两步验证</Title>
<Paragraph type="secondary">
请输入认证器应用显示的验证码完成登录
</Paragraph>
</div>
<Form onSubmit={handleSubmit}>
<Form.Input
field="code"
label={useBackupCode ? "备用码" : "验证码"}
placeholder={useBackupCode ? "请输入8位备用码" : "请输入6位验证码"}
value={verificationCode}
onChange={setVerificationCode}
onKeyPress={handleKeyPress}
size="large"
style={{ marginBottom: 16 }}
autoFocus
/>
<Button
htmlType="submit"
type="primary"
loading={loading}
block
size="large"
style={{ marginBottom: 16 }}
>
验证并登录
</Button>
</Form>
<Divider />
<div style={{ textAlign: 'center' }}>
<Button
theme="borderless"
type="tertiary"
onClick={() => {
setUseBackupCode(!useBackupCode);
setVerificationCode('');
}}
style={{ marginRight: 16, color: '#1890ff', padding: 0 }}
>
{useBackupCode ? '使用认证器验证码' : '使用备用码'}
</Button>
{onBack && (
<Button
theme="borderless"
type="tertiary"
onClick={onBack}
style={{ color: '#1890ff', padding: 0 }}
>
返回登录
</Button>
)}
</div>
<div style={{ marginTop: 24, padding: 16, background: '#f6f8fa', borderRadius: 6 }}>
<Text size="small" type="secondary">
<strong>提示</strong>
<br />
验证码每30秒更新一次
<br />
如果无法获取验证码请使用备用码
<br />
每个备用码只能使用一次
</Text>
</div>
</Card>
</div>
);
};
export default TwoFAVerification;

View File

@@ -0,0 +1,622 @@
import React, { useState, useEffect, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
Space,
Button,
Form,
Card,
Typography,
Banner,
Row,
Col,
InputNumber,
Switch,
Select,
Input,
} from '@douyinfe/semi-ui';
import {
IconCode,
IconEdit,
IconPlus,
IconDelete,
IconSetting,
} from '@douyinfe/semi-icons';
const { Text } = Typography;
const JSONEditor = ({
value = '',
onChange,
field,
label,
placeholder,
extraText,
showClear = true,
template,
templateLabel,
editorType = 'keyValue', // keyValue, object, region
autosize = true,
rules = [],
formApi = null,
...props
}) => {
const { t } = useTranslation();
// 初始化JSON数据
const [jsonData, setJsonData] = useState(() => {
// 初始化时解析JSON数据
if (value && value.trim()) {
try {
const parsed = JSON.parse(value);
return parsed;
} catch (error) {
return {};
}
}
return {};
});
// 根据键数量决定默认编辑模式
const [editMode, setEditMode] = useState(() => {
// 如果初始JSON数据的键数量大于10个则默认使用手动模式
if (value && value.trim()) {
try {
const parsed = JSON.parse(value);
const keyCount = Object.keys(parsed).length;
return keyCount > 10 ? 'manual' : 'visual';
} catch (error) {
// JSON无效时默认显示手动编辑模式
return 'manual';
}
}
return 'visual';
});
const [jsonError, setJsonError] = useState('');
// 数据同步 - 当value变化时总是更新jsonData如果JSON有效
useEffect(() => {
try {
const parsed = value && value.trim() ? JSON.parse(value) : {};
setJsonData(parsed);
setJsonError('');
} catch (error) {
console.log('JSON解析失败:', error.message);
setJsonError(error.message);
// JSON格式错误时不更新jsonData
}
}, [value]);
// 处理可视化编辑的数据变化
const handleVisualChange = useCallback((newData) => {
setJsonData(newData);
setJsonError('');
const jsonString = Object.keys(newData).length === 0 ? '' : JSON.stringify(newData, null, 2);
// 通过formApi设置值如果提供的话
if (formApi && field) {
formApi.setValue(field, jsonString);
}
onChange?.(jsonString);
}, [onChange, formApi, field]);
// 处理手动编辑的数据变化
const handleManualChange = useCallback((newValue) => {
onChange?.(newValue);
// 验证JSON格式
if (newValue && newValue.trim()) {
try {
const parsed = JSON.parse(newValue);
setJsonError('');
// 预先准备可视化数据,但不立即应用
// 这样切换到可视化模式时数据已经准备好了
} catch (error) {
setJsonError(error.message);
}
} else {
setJsonError('');
}
}, [onChange]);
// 切换编辑模式
const toggleEditMode = useCallback(() => {
if (editMode === 'visual') {
// 从可视化模式切换到手动模式
setEditMode('manual');
} else {
// 从手动模式切换到可视化模式需要验证JSON
try {
const parsed = value && value.trim() ? JSON.parse(value) : {};
setJsonData(parsed);
setJsonError('');
setEditMode('visual');
} catch (error) {
setJsonError(error.message);
// JSON格式错误时不切换模式
return;
}
}
}, [editMode, value]);
// 添加键值对
const addKeyValue = useCallback(() => {
const newData = { ...jsonData };
const keys = Object.keys(newData);
let newKey = 'key';
let counter = 1;
while (newData.hasOwnProperty(newKey)) {
newKey = `key${counter}`;
counter++;
}
newData[newKey] = '';
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 删除键值对
const removeKeyValue = useCallback((keyToRemove) => {
const newData = { ...jsonData };
delete newData[keyToRemove];
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 更新键名
const updateKey = useCallback((oldKey, newKey) => {
if (oldKey === newKey) return;
const newData = { ...jsonData };
const value = newData[oldKey];
delete newData[oldKey];
newData[newKey] = value;
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 更新值
const updateValue = useCallback((key, newValue) => {
const newData = { ...jsonData };
newData[key] = newValue;
handleVisualChange(newData);
}, [jsonData, handleVisualChange]);
// 填入模板
const fillTemplate = useCallback(() => {
if (template) {
const templateString = JSON.stringify(template, null, 2);
// 通过formApi设置值如果提供的话
if (formApi && field) {
formApi.setValue(field, templateString);
}
// 无论哪种模式都要更新值
onChange?.(templateString);
// 如果是可视化模式同时更新jsonData
if (editMode === 'visual') {
setJsonData(template);
}
// 清除错误状态
setJsonError('');
}
}, [template, onChange, editMode, formApi, field]);
// 渲染键值对编辑器
const renderKeyValueEditor = () => {
if (typeof jsonData !== 'object' || jsonData === null) {
return (
<div className="text-center py-6 px-4">
<div className="text-gray-400 mb-2">
<IconCode size={32} />
</div>
<Text type="tertiary" className="text-gray-500 text-sm">
{t('无效的JSON数据请检查格式')}
</Text>
</div>
);
}
const entries = Object.entries(jsonData);
return (
<div className="space-y-1">
{entries.length === 0 && (
<div className="text-center py-6 px-4">
<div className="text-gray-400 mb-2">
<IconCode size={32} />
</div>
<Text type="tertiary" className="text-gray-500 text-sm">
{t('暂无数据,点击下方按钮添加键值对')}
</Text>
</div>
)}
{entries.map(([key, value], index) => (
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
<Row gutter={12} align="middle">
<Col span={10}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('键名')}</Text>
<Input
placeholder={t('键名')}
value={key}
onChange={(newKey) => updateKey(key, newKey)}
size="small"
/>
</div>
</Col>
<Col span={11}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('值')}</Text>
<Input
placeholder={t('值')}
value={value}
onChange={(newValue) => updateValue(key, newValue)}
size="small"
/>
</div>
</Col>
<Col span={3}>
<div className="flex justify-center pt-4">
<Button
icon={<IconDelete />}
type="danger"
theme="borderless"
size="small"
onClick={() => removeKeyValue(key)}
className="hover:bg-red-50"
/>
</div>
</Col>
</Row>
</Card>
))}
<div className="flex justify-center pt-1">
<Button
icon={<IconPlus />}
onClick={addKeyValue}
size="small"
theme="solid"
type="primary"
className="shadow-sm hover:shadow-md transition-shadow px-4"
>
{t('添加键值对')}
</Button>
</div>
</div>
);
};
// 渲染对象编辑器用于复杂JSON
const renderObjectEditor = () => {
const entries = Object.entries(jsonData);
return (
<div className="space-y-1">
{entries.length === 0 && (
<div className="text-center py-6 px-4">
<div className="text-gray-400 mb-2">
<IconSetting size={32} />
</div>
<Text type="tertiary" className="text-gray-500 text-sm">
{t('暂无参数,点击下方按钮添加请求参数')}
</Text>
</div>
)}
{entries.map(([key, value], index) => (
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
<Row gutter={12} align="middle">
<Col span={8}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('参数名')}</Text>
<Input
placeholder={t('参数名')}
value={key}
onChange={(newKey) => updateKey(key, newKey)}
size="small"
/>
</div>
</Col>
<Col span={13}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('参数值')} ({typeof value})</Text>
{renderValueInput(key, value)}
</div>
</Col>
<Col span={3}>
<div className="flex justify-center pt-4">
<Button
icon={<IconDelete />}
type="danger"
theme="borderless"
size="small"
onClick={() => removeKeyValue(key)}
className="hover:bg-red-50"
/>
</div>
</Col>
</Row>
</Card>
))}
<div className="flex justify-center pt-1">
<Button
icon={<IconPlus />}
onClick={addKeyValue}
size="small"
theme="solid"
type="primary"
className="shadow-sm hover:shadow-md transition-shadow px-4"
>
{t('添加参数')}
</Button>
</div>
</div>
);
};
// 渲染参数值输入控件
const renderValueInput = (key, value) => {
const valueType = typeof value;
if (valueType === 'boolean') {
return (
<div className="flex items-center">
<Switch
checked={value}
onChange={(newValue) => updateValue(key, newValue)}
size="small"
/>
<Text type="tertiary" size="small" className="ml-2">
{value ? t('true') : t('false')}
</Text>
</div>
);
}
if (valueType === 'number') {
return (
<InputNumber
value={value}
onChange={(newValue) => updateValue(key, newValue)}
size="small"
style={{ width: '100%' }}
step={key === 'temperature' ? 0.1 : 1}
precision={key === 'temperature' ? 2 : 0}
placeholder={t('输入数字')}
/>
);
}
// 字符串类型或其他类型
return (
<Input
placeholder={t('参数值')}
value={String(value)}
onChange={(newValue) => {
// 尝试转换为适当的类型
let convertedValue = newValue;
if (newValue === 'true') convertedValue = true;
else if (newValue === 'false') convertedValue = false;
else if (!isNaN(newValue) && newValue !== '' && newValue !== '0') {
convertedValue = Number(newValue);
}
updateValue(key, convertedValue);
}}
size="small"
/>
);
};
// 渲染区域编辑器(特殊格式)
const renderRegionEditor = () => {
const entries = Object.entries(jsonData);
const defaultEntry = entries.find(([key]) => key === 'default');
const modelEntries = entries.filter(([key]) => key !== 'default');
return (
<div className="space-y-1">
{/* 默认区域 */}
<Card className="!p-2 !border-blue-200 !bg-blue-50">
<div className="flex items-center mb-1">
<Text strong size="small" className="text-blue-700">{t('默认区域')}</Text>
</div>
<Input
placeholder={t('默认区域,如: us-central1')}
value={defaultEntry ? defaultEntry[1] : ''}
onChange={(value) => updateValue('default', value)}
size="small"
/>
</Card>
{/* 模型专用区域 */}
<div className="space-y-1">
<Text strong size="small">{t('模型专用区域')}</Text>
{modelEntries.map(([modelName, region], index) => (
<Card key={index} className="!p-3 !border-gray-200 !rounded-md hover:shadow-sm transition-shadow duration-200">
<Row gutter={12} align="middle">
<Col span={10}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('模型名称')}</Text>
<Input
placeholder={t('模型名称')}
value={modelName}
onChange={(newKey) => updateKey(modelName, newKey)}
size="small"
/>
</div>
</Col>
<Col span={11}>
<div className="space-y-1">
<Text type="tertiary" size="small">{t('区域')}</Text>
<Input
placeholder={t('区域')}
value={region}
onChange={(newValue) => updateValue(modelName, newValue)}
size="small"
/>
</div>
</Col>
<Col span={3}>
<div className="flex justify-center pt-4">
<Button
icon={<IconDelete />}
type="danger"
theme="borderless"
size="small"
onClick={() => removeKeyValue(modelName)}
className="hover:bg-red-50"
/>
</div>
</Col>
</Row>
</Card>
))}
<div className="flex justify-center pt-1">
<Button
icon={<IconPlus />}
onClick={addKeyValue}
size="small"
theme="solid"
type="primary"
className="shadow-sm hover:shadow-md transition-shadow px-4"
>
{t('添加模型区域')}
</Button>
</div>
</div>
</div>
);
};
// 渲染可视化编辑器
const renderVisualEditor = () => {
switch (editorType) {
case 'region':
return renderRegionEditor();
case 'object':
return renderObjectEditor();
case 'keyValue':
default:
return renderKeyValueEditor();
}
};
const hasJsonError = jsonError && jsonError.trim() !== '';
return (
<div className="space-y-1">
{/* Label统一显示在上方 */}
{label && (
<div className="flex items-center">
<Text className="text-sm font-medium text-gray-900">{label}</Text>
</div>
)}
{/* 编辑模式切换 */}
<div className="flex items-center justify-between p-2 bg-gray-50 rounded-md">
<div className="flex items-center gap-2">
{editMode === 'visual' && (
<Text type="tertiary" size="small" className="bg-blue-100 text-blue-700 px-2 py-0.5 rounded text-xs">
{t('可视化模式')}
</Text>
)}
{editMode === 'manual' && (
<Text type="tertiary" size="small" className="bg-green-100 text-green-700 px-2 py-0.5 rounded text-xs">
{t('手动编辑模式')}
</Text>
)}
</div>
<div className="flex items-center gap-2">
{template && templateLabel && (
<Button
size="small"
type="tertiary"
onClick={fillTemplate}
className="!text-semi-color-primary hover:bg-blue-50 text-xs"
>
{templateLabel}
</Button>
)}
<Space size="tight">
<Button
size="small"
type={editMode === 'visual' ? 'primary' : 'tertiary'}
icon={<IconEdit />}
onClick={toggleEditMode}
disabled={editMode === 'manual' && hasJsonError}
className={editMode === 'visual' ? 'shadow-sm' : ''}
>
{t('可视化')}
</Button>
<Button
size="small"
type={editMode === 'manual' ? 'primary' : 'tertiary'}
icon={<IconCode />}
onClick={toggleEditMode}
className={editMode === 'manual' ? 'shadow-sm' : ''}
>
{t('手动编辑')}
</Button>
</Space>
</div>
</div>
{/* JSON错误提示 */}
{hasJsonError && (
<Banner
type="danger"
description={`JSON 格式错误: ${jsonError}`}
className="!rounded-md text-sm"
/>
)}
{/* 编辑器内容 */}
{editMode === 'visual' ? (
<div>
<Card className="!p-3 !border-gray-200 !shadow-sm !rounded-md bg-white">
{renderVisualEditor()}
</Card>
{/* 可视化模式下的额外文本显示在下方 */}
{extraText && (
<div className="text-xs text-gray-600 mt-0.5">
{extraText}
</div>
)}
{/* 隐藏的Form字段用于验证和数据绑定 */}
<Form.Input
field={field}
value={value}
rules={rules}
style={{ display: 'none' }}
noLabel={true}
{...props}
/>
</div>
) : (
<Form.TextArea
field={field}
placeholder={placeholder}
value={value}
onChange={handleManualChange}
showClear={showClear}
rows={Math.max(8, value ? value.split('\n').length : 8)}
rules={rules}
noLabel={true}
{...props}
/>
)}
{/* 额外文本在手动编辑模式下显示 */}
{extraText && editMode === 'manual' && (
<div className="text-xs text-gray-600">
{extraText}
</div>
)}
</div>
);
};
export default JSONEditor;

View File

@@ -33,7 +33,7 @@ import {
Settings,
} from 'lucide-react';
import { useTranslation } from 'react-i18next';
import { renderGroupOption, modelSelectFilter } from '../../helpers';
import { renderGroupOption, selectFilter } from '../../helpers';
import ParameterControl from './ParameterControl';
import ImageUrlInput from './ImageUrlInput';
import ConfigManager from './ConfigManager';
@@ -173,7 +173,7 @@ const SettingsPanel = ({
name='model'
required
selection
filter={modelSelectFilter}
filter={selectFilter}
autoClearSearchValue={false}
onChange={(value) => onInputChange('model', value)}
value={inputs.model}

View File

@@ -36,6 +36,7 @@ import {
renderModelTag,
getModelCategories
} from '../../helpers';
import TwoFASetting from './TwoFASetting';
import Turnstile from 'react-turnstile';
import { UserContext } from '../../context/User';
import { useTheme } from '../../context/Theme';
@@ -1041,6 +1042,9 @@ const PersonalSetting = () => {
</div>
</Card>
{/* 两步验证设置 */}
<TwoFASetting />
{/* 危险区域 */}
<Card
className="!rounded-xl border-red-200 w-full"

Some files were not shown because too many files have changed in this diff Show More