Compare commits

..

48 Commits

Author SHA1 Message Date
t0ng7u
e967094348 Merge branch 'sub' into feature/subscription 2026-02-03 02:10:04 +08:00
t0ng7u
47012e84c1 fix: standardize epay success response schema
Return subscription epay pay success responses via ApiSuccess to include the consistent success field and align with error schema.
2026-02-03 02:09:53 +08:00
t0ng7u
b8b40511f3 Merge branch 'sub' into feature/subscription 2026-02-03 02:07:12 +08:00
t0ng7u
58afec3771 fix: refine Japanese subscription status labels
Adjust Japanese UI wording for active-count labels to read more naturally and consistently.
2026-02-03 02:05:40 +08:00
t0ng7u
e48b74f469 Merge branch 'sub' into feature/subscription 2026-02-03 02:03:47 +08:00
t0ng7u
c1061b2d18 🛡️ fix: fail fast on epay form parse errors
Handle ParseForm errors in epay notify/return handlers by returning fail or redirecting to failure, avoiding unsafe fallback to query parameters.
2026-02-03 02:03:25 +08:00
t0ng7u
4e9c5bb45b Merge branch 'sub' into feature/subscription 2026-02-03 01:59:05 +08:00
t0ng7u
f578aa8e00 🔧 fix: harden billing flow and sidebar settings
Add missing strings import for subscription fallback checks, log failed subscription refunds after retries, and extend sidebar module settings with a subscription management toggle plus translations.
2026-02-03 01:58:49 +08:00
t0ng7u
732484ceaa Merge branch 'sub' into feature/subscription 2026-02-03 01:51:31 +08:00
t0ng7u
f521a430ce 🔧 fix: harden epay callbacks and billing fallbacks
Use POST and form parsing for epay notify/return routes, persist epay orders before provider calls with expiry on failure, and ensure notify handlers retry correctly.
Restrict subscription-first fallback to insufficient-subscription errors and log refund failures after retries to avoid silent quota drift.
2026-02-03 01:51:16 +08:00
t0ng7u
11eef1ce77 Merge branch 'sub' into feature/subscription 2026-02-03 01:29:45 +08:00
t0ng7u
1e2c039f40 Merge remote-tracking branch 'newapi/main' into sub 2026-02-03 01:29:19 +08:00
t0ng7u
3d177f3020 Merge branch 'sub' into feature/subscription 2026-02-03 01:24:40 +08:00
t0ng7u
0486a5d83b 🧾 fix: persist epay orders before purchase
Create the subscription order before initiating epay payment and expire it if the provider call fails, preventing orphaned transactions and improving reconciliation.
2026-02-03 01:24:25 +08:00
t0ng7u
2cdc37fdc4 Merge branch 'sub' into feature/subscription 2026-02-03 00:24:16 +08:00
t0ng7u
49ac355357 🔧 fix: normalize epay error handling and webhook retries
Standardize SubscriptionRequestEpay error responses via ApiErrorMsg for a consistent schema.
Return "fail" on non-success trade statuses in the epay webhook to preserve retry behavior.
2026-02-03 00:23:51 +08:00
t0ng7u
414f86fb4b Merge branch 'sub' into feature/subscription 2026-02-03 00:12:20 +08:00
t0ng7u
6b694c9c94 🚦 fix: guard epay return success on order completion
Redirect subscription return flow to failure when order completion fails, preventing false success states after payment verification.
2026-02-03 00:10:07 +08:00
t0ng7u
b942d4eebd Merge branch 'sub' into feature/subscription 2026-02-03 00:02:04 +08:00
t0ng7u
ef44a341a8 🔧 fix: make epay webhook and return flow subscription-aware
Ensure Epay webhook acknowledges success only after order completion, returning fail on processing errors to allow retries. Redirect subscription payment returns to the subscription page instead of top-up for correct user flow.
2026-02-03 00:01:24 +08:00
t0ng7u
70a8b30aab Merge branch 'sub' into feature/subscription 2026-02-02 23:45:05 +08:00
t0ng7u
34e5720773 feat: harden subscription billing and improve UI consistency
Improve subscription payment safety and data integrity by handling user/URL lookup failures, fixing Stripe subscription mode, persisting quota reset fields, and correcting subscription delta accounting and DB timestamp casting. Refine the UI with stricter custom duration validation, accurate currency rounding, conditional Epay labeling, rollback on preference update failure, and shared subscription formatting helpers plus clearer component naming.
2026-02-02 23:44:53 +08:00
t0ng7u
4057eedaff Merge branch 'sub' into feature/subscription 2026-02-02 23:09:44 +08:00
t0ng7u
1fba3c064b Add full i18n coverage for subscription-related UI across locales 2026-02-02 23:09:27 +08:00
t0ng7u
120256a52c 🚀 chore: Remove useless action 2026-02-02 17:06:15 +08:00
t0ng7u
16349c98cb Merge remote-tracking branch 'newapi/main' into sub
# Conflicts:
#	main.go
#	web/src/components/table/usage-logs/UsageLogsColumnDefs.jsx
#	web/src/pages/Setting/Payment/SettingsPaymentGatewayCreem.jsx
2026-02-02 17:03:02 +08:00
t0ng7u
a74cc93bbc 🔧 chore: remove unused Creem settings state
Drop the unused originInputs state and redundant updates to keep the Creem
settings form state minimal and easier to maintain.
2026-02-02 13:00:37 +08:00
t0ng7u
e8bd2e0d53 chore: Add upgrade group guidance in subscription editor
Add explanatory helper text under the upgrade group field to clarify automatic group upgrades, rollback conditions, and the expected delay before downgrading takes effect.
2026-02-01 15:47:34 +08:00
t0ng7u
de90e11cdf feat: Extract quota conversion helpers to shared utils
Move quota display/conversion helpers into web/src/helpers/quota.js and update the subscription plan editor to import and use the shared utilities instead of inline functions.
2026-02-01 15:40:26 +08:00
t0ng7u
f0e60df96e feat: Update subscription purchase modal display
Show total quota as currency with tooltip for raw quota, hide reset cycle when never, and display upgrade group when configured to match card display rules.
2026-02-01 02:28:50 +08:00
t0ng7u
96caec1626 feat: Add subscription upgrade group with auto downgrade 2026-02-01 02:17:17 +08:00
t0ng7u
c22ca9cdb3 🚀 chore: Remove duplicate subscription usage percentage display
Keep the usage percentage shown only in the total quota line to avoid redundant “已用 0%” text while preserving remaining days in the summary.
2026-02-01 00:43:09 +08:00
t0ng7u
6300c31d70 🚀 refactor: Simplify subscription quota to total amount model
Remove per-model subscription items and switch to a single total quota per plan and user subscription. Update billing, reset, and logging flows to operate on total quota, and refactor admin/user UI to configure and display total quota consistently.
2026-02-01 00:35:08 +08:00
t0ng7u
b92a4ee987 🎨 style: tag color to white 2026-01-31 15:05:09 +08:00
t0ng7u
cf67af3b14 feat: Add subscription limits and UI tags consistency
Add per-plan purchase limits with backend enforcement and UI disable states.
Expose limit configuration in admin plan editor and show limits in plan tables/cards.
Refine subscription UI tags with unified badge style and streamlined “My Subscriptions” layout.
2026-01-31 15:02:03 +08:00
t0ng7u
2297af731c 🔧 chore: Unify subscription plan status toggle with PATCH endpoint
Replace separate enable/disable flows with a single PATCH API that updates the enabled flag.
Update frontend hooks and table actions to call the unified endpoint and keep UI behavior consistent.
Introduce a minimal admin controller handler and route for the status update.
2026-01-31 14:27:01 +08:00
t0ng7u
28c5feb570 💸 chore: Align subscription pricing display with global currency settings
Unify subscription price rendering to use the site-wide currency symbol/rate on the wallet and admin views.
Make subscription plan currency read-only in the editor and force USD on create/update to avoid drift.
Use global currency display type when creating Creem checkout payloads.
2026-01-31 13:41:55 +08:00
t0ng7u
354da6ea6b 🔧 ci: Change workflow trigger to sub branch
Update the Docker image workflow to run on pushes to the sub branch instead of main.
2026-01-31 13:19:26 +08:00
t0ng7u
a0c23a0648 🐛 fix(subscription): avoid pre-consume lookup noise
Use a RowsAffected check for the idempotency lookup so missing records
no longer surface as "record not found" errors while preserving behavior.
2026-01-31 01:18:47 +08:00
t0ng7u
41489fc32a feat(subscription): cache plan lookups and stabilize pre-consume
Introduce hybrid caches for subscription plans, items, and plan info with explicit
invalidation on admin updates. Streamline pre-consume transactions to reduce
redundant queries while preserving idempotency and reset logic.
2026-01-31 01:12:54 +08:00
t0ng7u
ffebb35499 feat(subscription): harden subscription billing with resets, idempotency, and production-grade stability
Add plan-level quota reset periods and display/reset cadence in admin/UI
Enforce natural reset alignment with background reset task and cleanup job
Make subscription pre-consume/refund idempotent with request-scoped records and retries
Use database time for consistent resets across multi-instance deployments
Harden payment callbacks with locking and idempotent order completion
Record subscription purchases in topup history and billing logs
Optimize subscription queries and add critical composite indexes
2026-01-31 00:31:47 +08:00
t0ng7u
5707ee3492 feat(subscription): add quota reset periods and admin configuration
- Add reset period fields on subscription plans and user items
- Apply automatic quota resets during pre-consume based on plan schedule
- Expose reset-period configuration in the admin plan editor
- Display reset cadence in subscription cards and purchase modal
- Validate custom reset seconds on plan create/update
2026-01-31 00:06:13 +08:00
t0ng7u
ecf50b754a 🎨 style: format all code with gofmt and lint:fix
Apply consistent code formatting across the entire codebase using
gofmt and lint:fix tools. This ensures adherence to Go community
standards and improves code readability and maintainability.

Changes include:
- Run gofmt on all .go files to standardize formatting
- Apply lint:fix to automatically resolve linting issues
- Fix code style inconsistencies and formatting violations

No functional changes were made in this commit.
2026-01-30 23:43:27 +08:00
t0ng7u
697cbbf752 fix(subscription): finalize payments, log billing, and clean up dead code
Complete subscription orders by creating a matching top-up record and writing billing logs
Add Epay return handler to verify and finalize browser callbacks
Require Stripe/Creem webhook configuration before starting subscription payments
Show subscription purchases in topup history with clearer labels/methods
Remove unused subscription helper, legacy Creem webhook struct, and unused topup fields
Simplify subscription self API payload to active/all lists only
2026-01-30 23:40:01 +08:00
t0ng7u
a60783e99f feat(admin): streamline subscription plan benefits editor with bulk actions
Restore the avatar/icon header for the “Model Benefits” section
Replace scattered controls with a compact toolbar-style workflow
Support multi-select add with a default quota for new items
Add row selection with bulk apply-to-selected / apply-to-all quota updates
Enable delete-selected to manage benefits faster and reduce mistakes
2026-01-30 16:24:51 +08:00
t0ng7u
348ae6df73 feat(admin): add user subscription management and refine UI/pagination
Add admin APIs to list/create/invalidate/delete user subscriptions
Add model helpers to fetch all user subscriptions (incl. expired) and support cancel/hard-delete
Wire new admin routes for user subscription operations
Replace “Bind subscription plan” entry with a dedicated User Subscriptions SideSheet in Users table
Use CardTable with responsive layout and working client-side pagination inside the SideSheet
Improve subscription purchase modal empty-gateway state with a Banner notice
2026-01-30 14:29:56 +08:00
t0ng7u
009910b960 feat: add subscription billing system with admin management and user purchase flow
Implement a new subscription-based billing model alongside existing metered/per-request billing:

Backend:
- Add subscription plan models (SubscriptionPlan, SubscriptionPlanItem, UserSubscription, etc.)
- Implement CRUD APIs for subscription plan management (admin only)
- Add user subscription queries with support for multiple active/expired subscriptions
- Integrate payment gateways (Stripe, Creem, Epay) for subscription purchases
- Implement pre-consume and post-consume billing logic for subscription quota tracking
- Add billing preference settings (subscription_first, wallet_first, etc.)
- Enhance usage logs with subscription deduction details

Frontend - Admin:
- Add subscription management page with table view and drawer-based edit form
- Match UI/UX style with existing admin pages (redemption codes, users)
- Support enabling/disabling plans, configuring payment IDs, and model quotas
- Add user subscription binding modal in user management

Frontend - Wallet:
- Add subscription plans card with current subscription status display
- Show all subscriptions (active and expired) with remaining days/usage percentage
- Display purchasable plans with pricing cards following SaaS best practices
- Extract purchase modal to separate component matching payment confirm modal style
- Add skeleton loading states with active animation
- Implement billing preference selector in card header
- Handle payment gateway availability based on admin configuration

Frontend - Usage Logs:
- Display subscription deduction details in log entries
- Show step-by-step breakdown of subscription usage (pre-consumed, delta, final, remaining)
- Add subscription deduction tag for subscription-covered requests
2026-01-30 05:31:10 +08:00
t0ng7u
c6c12d340f ci: create docker automation 2026-01-30 01:58:59 +08:00
170 changed files with 5086 additions and 9466 deletions

38
.gitattributes vendored
View File

@@ -1,38 +0,0 @@
# Auto detect text files and perform LF normalization
* text=auto
# Go files
*.go text eol=lf
# Config files
*.json text eol=lf
*.yaml text eol=lf
*.yml text eol=lf
*.toml text eol=lf
*.md text eol=lf
# JavaScript/TypeScript files
*.js text eol=lf
*.jsx text eol=lf
*.ts text eol=lf
*.tsx text eol=lf
*.html text eol=lf
*.css text eol=lf
# Shell scripts
*.sh text eol=lf
# Binary files
*.png binary
*.jpg binary
*.jpeg binary
*.gif binary
*.ico binary
*.woff binary
*.woff2 binary
# ============================================
# GitHub Linguist - Language Detection
# ============================================
# Mark web frontend as vendored so GitHub recognizes this as a Go project
electron/** linguist-vendored

View File

@@ -4,12 +4,6 @@ on:
push:
tags:
- '*'
workflow_dispatch:
inputs:
tag:
description: 'Tag name to build (e.g., v0.10.8-alpha.3)'
required: true
type: string
jobs:
build_single_arch:
@@ -31,24 +25,15 @@ jobs:
contents: read
steps:
- name: Check out
- name: Check out (shallow)
uses: actions/checkout@v4
with:
fetch-depth: ${{ github.event_name == 'workflow_dispatch' && 0 || 1 }}
ref: ${{ github.event.inputs.tag || github.ref }}
fetch-depth: 1
- name: Resolve tag & write VERSION
run: |
if [ -n "${{ github.event.inputs.tag }}" ]; then
TAG="${{ github.event.inputs.tag }}"
# Verify tag exists
if ! git rev-parse "refs/tags/$TAG" >/dev/null 2>&1; then
echo "Error: Tag '$TAG' does not exist in the repository"
exit 1
fi
else
TAG=${GITHUB_REF#refs/tags/}
fi
git fetch --tags --force --depth=1
TAG=${GITHUB_REF#refs/tags/}
echo "TAG=$TAG" >> $GITHUB_ENV
echo "$TAG" > VERSION
echo "Building tag: $TAG for ${{ matrix.arch }}"
@@ -102,15 +87,10 @@ jobs:
name: Create multi-arch manifests (Docker Hub)
needs: [build_single_arch]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch'
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Extract tag
run: |
if [ -n "${{ github.event.inputs.tag }}" ]; then
echo "TAG=${{ github.event.inputs.tag }}" >> $GITHUB_ENV
else
echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
fi
run: echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
#
# - name: Normalize GHCR repository
# run: echo "GHCR_REPOSITORY=${GITHUB_REPOSITORY,,}" >> $GITHUB_ENV

View File

@@ -445,14 +445,6 @@ Bienvenue à toutes les formes de contribution!
---
## 📜 Licence
Ce projet est sous licence [GNU Affero General Public License v3.0 (AGPLv3)](./LICENSE).
Si les politiques de votre organisation ne permettent pas l'utilisation de logiciels sous licence AGPLv3, ou si vous souhaitez éviter les obligations open-source de l'AGPLv3, veuillez nous contacter à : [support@quantumnous.com](mailto:support@quantumnous.com)
---
## 🌟 Historique des étoiles
<div align="center">

View File

@@ -445,14 +445,6 @@ docker run --name new-api -d --restart always \
---
## 📜 ライセンス
このプロジェクトは [GNU Affero General Public License v3.0 (AGPLv3)](./LICENSE) の下でライセンスされています。
お客様の組織のポリシーがAGPLv3ライセンスのソフトウェアの使用を許可していない場合、またはAGPLv3のオープンソース義務を回避したい場合は、こちらまでお問い合わせください[support@quantumnous.com](mailto:support@quantumnous.com)
---
## 🌟 スター履歴
<div align="center">

View File

@@ -445,14 +445,6 @@ Welcome all forms of contribution!
---
## 📜 License
This project is licensed under the [GNU Affero General Public License v3.0 (AGPLv3)](./LICENSE).
If your organization's policies do not permit the use of AGPLv3-licensed software, or if you wish to avoid the open-source obligations of AGPLv3, please contact us at: [support@quantumnous.com](mailto:support@quantumnous.com)
---
## 🌟 Star History
<div align="center">

View File

@@ -445,14 +445,6 @@ docker run --name new-api -d --restart always \
---
## 📜 许可证
本项目采用 [GNU Affero 通用公共许可证 v3.0 (AGPLv3)](./LICENSE) 授权。
如果您所在的组织政策不允许使用 AGPLv3 许可的软件,或您希望规避 AGPLv3 的开源义务,请发送邮件至:[support@quantumnous.com](mailto:support@quantumnous.com)
---
## 🌟 Star History
<div align="center">

View File

@@ -5,9 +5,12 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
)
// BodyStorage 请求体存储接口
@@ -98,10 +101,25 @@ type diskStorage struct {
}
func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) {
// 使用统一的缓存目录管理
filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody)
// 确定缓存目录
dir := cachePath
if dir == "" {
dir = os.TempDir()
}
dir = filepath.Join(dir, "new-api-body-cache")
// 确保目录存在
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create cache directory: %w", err)
}
// 创建临时文件
filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano())
filePath := filepath.Join(dir, filename)
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create temp file: %w", err)
}
// 写入数据
@@ -130,10 +148,25 @@ func newDiskStorage(data []byte, cachePath string) (*diskStorage, error) {
}
func newDiskStorageFromReader(reader io.Reader, maxBytes int64, cachePath string) (*diskStorage, error) {
// 使用统一的缓存目录管理
filePath, file, err := CreateDiskCacheFile(DiskCacheTypeBody)
// 确定缓存目录
dir := cachePath
if dir == "" {
dir = os.TempDir()
}
dir = filepath.Join(dir, "new-api-body-cache")
// 确保目录存在
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create cache directory: %w", err)
}
// 创建临时文件
filename := fmt.Sprintf("body-%s-%d.tmp", uuid.New().String()[:8], time.Now().UnixNano())
filePath := filepath.Join(dir, filename)
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create temp file: %w", err)
}
// 从 reader 读取并写入文件
@@ -304,6 +337,29 @@ func CreateBodyStorageFromReader(reader io.Reader, contentLength int64, maxBytes
// CleanupOldCacheFiles 清理旧的缓存文件(用于启动时清理残留)
func CleanupOldCacheFiles() {
// 使用统一的缓存管理
CleanupOldDiskCacheFiles(5 * time.Minute)
cachePath := GetDiskCachePath()
if cachePath == "" {
cachePath = os.TempDir()
}
dir := filepath.Join(cachePath, "new-api-body-cache")
entries, err := os.ReadDir(dir)
if err != nil {
return // 目录不存在或无法读取
}
now := time.Now()
for _, entry := range entries {
if entry.IsDir() {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
// 删除超过 5 分钟的旧文件
if now.Sub(info.ModTime()) > 5*time.Minute {
os.Remove(filepath.Join(dir, entry.Name()))
}
}
}

View File

@@ -39,7 +39,7 @@ var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 1000
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
@@ -175,10 +175,6 @@ var (
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
// Per-user search rate limit (applies after authentication, keyed by user ID)
SearchRateLimitNum = 10
SearchRateLimitDuration int64 = 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute

View File

@@ -1,176 +0,0 @@
package common
import (
"fmt"
"os"
"path/filepath"
"time"
"github.com/google/uuid"
)
// DiskCacheType 磁盘缓存类型
type DiskCacheType string
const (
DiskCacheTypeBody DiskCacheType = "body" // 请求体缓存
DiskCacheTypeFile DiskCacheType = "file" // 文件数据缓存
)
// 统一的缓存目录名
const diskCacheDir = "new-api-body-cache"
// GetDiskCacheDir 获取统一的磁盘缓存目录
// 注意:每次调用都会重新计算,以响应配置变化
func GetDiskCacheDir() string {
cachePath := GetDiskCachePath()
if cachePath == "" {
cachePath = os.TempDir()
}
return filepath.Join(cachePath, diskCacheDir)
}
// EnsureDiskCacheDir 确保缓存目录存在
func EnsureDiskCacheDir() error {
dir := GetDiskCacheDir()
return os.MkdirAll(dir, 0755)
}
// CreateDiskCacheFile 创建磁盘缓存文件
// cacheType: 缓存类型body/file
// 返回文件路径和文件句柄
func CreateDiskCacheFile(cacheType DiskCacheType) (string, *os.File, error) {
if err := EnsureDiskCacheDir(); err != nil {
return "", nil, fmt.Errorf("failed to create cache directory: %w", err)
}
dir := GetDiskCacheDir()
filename := fmt.Sprintf("%s-%s-%d.tmp", cacheType, uuid.New().String()[:8], time.Now().UnixNano())
filePath := filepath.Join(dir, filename)
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_RDWR|os.O_EXCL, 0600)
if err != nil {
return "", nil, fmt.Errorf("failed to create cache file: %w", err)
}
return filePath, file, nil
}
// WriteDiskCacheFile 写入数据到磁盘缓存文件
// 返回文件路径
func WriteDiskCacheFile(cacheType DiskCacheType, data []byte) (string, error) {
filePath, file, err := CreateDiskCacheFile(cacheType)
if err != nil {
return "", err
}
_, err = file.Write(data)
if err != nil {
file.Close()
os.Remove(filePath)
return "", fmt.Errorf("failed to write cache file: %w", err)
}
if err := file.Close(); err != nil {
os.Remove(filePath)
return "", fmt.Errorf("failed to close cache file: %w", err)
}
return filePath, nil
}
// WriteDiskCacheFileString 写入字符串到磁盘缓存文件
func WriteDiskCacheFileString(cacheType DiskCacheType, data string) (string, error) {
return WriteDiskCacheFile(cacheType, []byte(data))
}
// ReadDiskCacheFile 读取磁盘缓存文件
func ReadDiskCacheFile(filePath string) ([]byte, error) {
return os.ReadFile(filePath)
}
// ReadDiskCacheFileString 读取磁盘缓存文件为字符串
func ReadDiskCacheFileString(filePath string) (string, error) {
data, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
return string(data), nil
}
// RemoveDiskCacheFile 删除磁盘缓存文件
func RemoveDiskCacheFile(filePath string) error {
return os.Remove(filePath)
}
// CleanupOldDiskCacheFiles 清理旧的缓存文件
// maxAge: 文件最大存活时间
// 注意:此函数只删除文件,不更新统计(因为无法知道每个文件的原始大小)
func CleanupOldDiskCacheFiles(maxAge time.Duration) error {
dir := GetDiskCacheDir()
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
return nil // 目录不存在,无需清理
}
return err
}
now := time.Now()
for _, entry := range entries {
if entry.IsDir() {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
if now.Sub(info.ModTime()) > maxAge {
// 注意:后台清理任务删除文件时,由于无法得知原始 base64Size
// 只能按磁盘文件大小扣减。这在目前 base64 存储模式下是准确的。
if err := os.Remove(filepath.Join(dir, entry.Name())); err == nil {
DecrementDiskFiles(info.Size())
}
}
}
return nil
}
// GetDiskCacheInfo 获取磁盘缓存目录信息
func GetDiskCacheInfo() (fileCount int, totalSize int64, err error) {
dir := GetDiskCacheDir()
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
return 0, 0, nil
}
return 0, 0, err
}
for _, entry := range entries {
if entry.IsDir() {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
fileCount++
totalSize += info.Size()
}
return fileCount, totalSize, nil
}
// ShouldUseDiskCache 判断是否应该使用磁盘缓存
func ShouldUseDiskCache(dataSize int64) bool {
if !IsDiskCacheEnabled() {
return false
}
threshold := GetDiskCacheThresholdBytes()
if dataSize < threshold {
return false
}
return IsDiskCacheAvailable(dataSize)
}

View File

@@ -113,12 +113,8 @@ func IncrementDiskFiles(size int64) {
// DecrementDiskFiles 减少磁盘文件计数
func DecrementDiskFiles(size int64) {
if atomic.AddInt64(&diskCacheStats.ActiveDiskFiles, -1) < 0 {
atomic.StoreInt64(&diskCacheStats.ActiveDiskFiles, 0)
}
if atomic.AddInt64(&diskCacheStats.CurrentDiskUsageBytes, -size) < 0 {
atomic.StoreInt64(&diskCacheStats.CurrentDiskUsageBytes, 0)
}
atomic.AddInt64(&diskCacheStats.ActiveDiskFiles, -1)
atomic.AddInt64(&diskCacheStats.CurrentDiskUsageBytes, -size)
}
// IncrementMemoryBuffers 增加内存缓存计数
@@ -143,29 +139,12 @@ func IncrementMemoryCacheHits() {
atomic.AddInt64(&diskCacheStats.MemoryCacheHits, 1)
}
// ResetDiskCacheStats 重置命中统计信息(不重置当前使用量)
// ResetDiskCacheStats 重置统计信息(不重置当前使用量)
func ResetDiskCacheStats() {
atomic.StoreInt64(&diskCacheStats.DiskCacheHits, 0)
atomic.StoreInt64(&diskCacheStats.MemoryCacheHits, 0)
}
// ResetDiskCacheUsage 重置磁盘缓存使用量统计(用于清理缓存后)
func ResetDiskCacheUsage() {
atomic.StoreInt64(&diskCacheStats.ActiveDiskFiles, 0)
atomic.StoreInt64(&diskCacheStats.CurrentDiskUsageBytes, 0)
}
// SyncDiskCacheStats 从实际磁盘状态同步统计信息
// 用于修正统计与实际不符的情况
func SyncDiskCacheStats() {
fileCount, totalSize, err := GetDiskCacheInfo()
if err != nil {
return
}
atomic.StoreInt64(&diskCacheStats.ActiveDiskFiles, int64(fileCount))
atomic.StoreInt64(&diskCacheStats.CurrentDiskUsageBytes, totalSize)
}
// IsDiskCacheAvailable 检查是否可以创建新的磁盘缓存
func IsDiskCacheAvailable(requestSize int64) bool {
if !IsDiskCacheEnabled() {

View File

@@ -218,39 +218,6 @@ func ApiSuccess(c *gin.Context, data any) {
})
}
// ApiErrorI18n returns a translated error message based on the user's language preference
// key is the i18n message key, args is optional template data
func ApiErrorI18n(c *gin.Context, key string, args ...map[string]any) {
msg := TranslateMessage(c, key, args...)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": msg,
})
}
// ApiSuccessI18n returns a translated success message based on the user's language preference
func ApiSuccessI18n(c *gin.Context, key string, data any, args ...map[string]any) {
msg := TranslateMessage(c, key, args...)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": msg,
"data": data,
})
}
// TranslateMessage is a helper function that calls i18n.T
// This function is defined here to avoid circular imports
// The actual implementation will be set during init
var TranslateMessage func(c *gin.Context, key string, args ...map[string]any) string
func init() {
// Default implementation that returns the key as-is
// This will be replaced by i18n.T during i18n initialization
TranslateMessage = func(c *gin.Context, key string, args ...map[string]any) string {
return key
}
}
func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
requestBody, err := GetRequestBody(c)
if err != nil {

View File

@@ -137,6 +137,7 @@ func initConstantEnv() {
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
constant.NotifyLimitCount = GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
constant.NotificationLimitDurationMinute = GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
// GenerateDefaultToken 是否生成初始令牌,默认关闭。

View File

@@ -1,33 +0,0 @@
package common
import "sync/atomic"
// PerformanceMonitorConfig 性能监控配置
type PerformanceMonitorConfig struct {
Enabled bool
CPUThreshold int
MemoryThreshold int
DiskThreshold int
}
var performanceMonitorConfig atomic.Value
func init() {
// 初始化默认配置
performanceMonitorConfig.Store(PerformanceMonitorConfig{
Enabled: true,
CPUThreshold: 90,
MemoryThreshold: 90,
DiskThreshold: 90,
})
}
// GetPerformanceMonitorConfig 获取性能监控配置
func GetPerformanceMonitorConfig() PerformanceMonitorConfig {
return performanceMonitorConfig.Load().(PerformanceMonitorConfig)
}
// SetPerformanceMonitorConfig 设置性能监控配置
func SetPerformanceMonitorConfig(config PerformanceMonitorConfig) {
performanceMonitorConfig.Store(config)
}

View File

@@ -1,81 +0,0 @@
package common
import (
"sync/atomic"
"time"
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/mem"
)
// DiskSpaceInfo 磁盘空间信息
type DiskSpaceInfo struct {
// 总空间(字节)
Total uint64 `json:"total"`
// 可用空间(字节)
Free uint64 `json:"free"`
// 已用空间(字节)
Used uint64 `json:"used"`
// 使用百分比
UsedPercent float64 `json:"used_percent"`
}
// SystemStatus 系统状态信息
type SystemStatus struct {
CPUUsage float64
MemoryUsage float64
DiskUsage float64
}
var latestSystemStatus atomic.Value
func init() {
latestSystemStatus.Store(SystemStatus{})
}
// StartSystemMonitor 启动系统监控
func StartSystemMonitor() {
go func() {
for {
config := GetPerformanceMonitorConfig()
if !config.Enabled {
time.Sleep(30 * time.Second)
continue
}
updateSystemStatus()
time.Sleep(5 * time.Second)
}
}()
}
func updateSystemStatus() {
var status SystemStatus
// CPU
// 注意cpu.Percent(0, false) 返回自上次调用以来的 CPU 使用率
// 如果是第一次调用,可能会返回错误或不准确的值,但在循环中会逐渐正常
percents, err := cpu.Percent(0, false)
if err == nil && len(percents) > 0 {
status.CPUUsage = percents[0]
}
// Memory
memInfo, err := mem.VirtualMemory()
if err == nil {
status.MemoryUsage = memInfo.UsedPercent
}
// Disk
diskInfo := GetDiskSpaceInfo()
if diskInfo.Total > 0 {
status.DiskUsage = diskInfo.UsedPercent
}
latestSystemStatus.Store(status)
}
// GetSystemStatus 获取当前系统状态
func GetSystemStatus() SystemStatus {
return latestSystemStatus.Load().(SystemStatus)
}

View File

@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
return fmt.Sprintf("%f", inter.(float64))
case bool:
if inter.(bool) {
return "true"

View File

@@ -56,13 +56,7 @@ const (
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
// ContextKeyFileSourcesToCleanup stores file sources that need cleanup when request ends
ContextKeyFileSourcesToCleanup ContextKey = "file_sources_to_cleanup"
// ContextKeyAdminRejectReason stores an admin-only reject/block reason extracted from upstream responses.
// It is not returned to end users, but can be persisted into consume/error logs for debugging.
ContextKeyAdminRejectReason ContextKey = "admin_reject_reason"
// ContextKeyLanguage stores the user's language preference for i18n
ContextKeyLanguage ContextKey = "language"
)

View File

@@ -11,6 +11,7 @@ var GetMediaTokenNotStream bool
var UpdateTask bool
var MaxRequestBodyMB int
var AzureDefaultAPIVersion string
var GeminiVisionMaxImageNum int
var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool

View File

@@ -89,8 +89,7 @@ func GetAllChannels(c *gin.Context) {
if enableTagMode {
tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.SysError("failed to get paginated tags: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取标签失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
for _, tag := range tags {
@@ -137,8 +136,7 @@ func GetAllChannels(c *gin.Context) {
err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
if err != nil {
common.SysError("failed to get channels: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道列表失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
}
@@ -643,8 +641,7 @@ func RefreshCodexChannelCredential(c *gin.Context) {
oauthKey, ch, err := service.RefreshCodexChannelCredential(ctx, channelId, service.CodexCredentialRefreshOptions{ResetCaches: true})
if err != nil {
common.SysError("failed to refresh codex channel credential: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "刷新凭证失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
@@ -1318,8 +1315,7 @@ func CopyChannel(c *gin.Context) {
// fetch original channel with key
origin, err := model.GetChannelById(id, true)
if err != nil {
common.SysError("failed to get channel by id: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取渠道信息失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
@@ -1337,8 +1333,7 @@ func CopyChannel(c *gin.Context) {
// insert
if err := model.BatchInsertChannels([]model.Channel{clone}); err != nil {
common.SysError("failed to clone channel: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "复制渠道失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
model.InitChannelCache()

View File

@@ -132,8 +132,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
code, state, err := parseCodexAuthorizationInput(req.Input)
if err != nil {
common.SysError("failed to parse codex authorization input: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析授权信息失败,请检查输入格式"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
if strings.TrimSpace(code) == "" {
@@ -178,8 +177,7 @@ func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) {
tokenRes, err := service.ExchangeCodexAuthorizationCode(ctx, code, verifier)
if err != nil {
common.SysError("failed to exchange codex authorization code: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}

View File

@@ -45,8 +45,7 @@ func GetCodexChannelUsage(c *gin.Context) {
oauthKey, err := codex.ParseOAuthKey(strings.TrimSpace(ch.Key))
if err != nil {
common.SysError("failed to parse oauth key: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析凭证失败,请检查渠道配置"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
accessToken := strings.TrimSpace(oauthKey.AccessToken)
@@ -71,8 +70,7 @@ func GetCodexChannelUsage(c *gin.Context) {
statusCode, body, err := service.FetchCodexWhamUsage(ctx, client, ch.GetBaseURL(), accessToken, accountID)
if err != nil {
common.SysError("failed to fetch codex usage: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
@@ -101,8 +99,7 @@ func GetCodexChannelUsage(c *gin.Context) {
defer cancel2()
statusCode, body, err = service.FetchCodexWhamUsage(ctx2, client, ch.GetBaseURL(), oauthKey.AccessToken, accountID)
if err != nil {
common.SysError("failed to fetch codex usage after refresh: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取用量信息失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
}

View File

@@ -17,8 +17,7 @@ func MigrateConsoleSetting(c *gin.Context) {
// 读取全部 option
opts, err := model.AllOption()
if err != nil {
common.SysError("failed to get all options: " + err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "获取配置失败,请稍后重试"})
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
return
}
// 建立 map

View File

@@ -1,386 +0,0 @@
package controller
import (
"net/http"
"strconv"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/gin-gonic/gin"
)
// CustomOAuthProviderResponse is the response structure for custom OAuth providers
// It excludes sensitive fields like client_secret
type CustomOAuthProviderResponse struct {
Id int `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
}
func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse {
return &CustomOAuthProviderResponse{
Id: p.Id,
Name: p.Name,
Slug: p.Slug,
Enabled: p.Enabled,
ClientId: p.ClientId,
AuthorizationEndpoint: p.AuthorizationEndpoint,
TokenEndpoint: p.TokenEndpoint,
UserInfoEndpoint: p.UserInfoEndpoint,
Scopes: p.Scopes,
UserIdField: p.UserIdField,
UsernameField: p.UsernameField,
DisplayNameField: p.DisplayNameField,
EmailField: p.EmailField,
WellKnown: p.WellKnown,
AuthStyle: p.AuthStyle,
}
}
// GetCustomOAuthProviders returns all custom OAuth providers
func GetCustomOAuthProviders(c *gin.Context) {
providers, err := model.GetAllCustomOAuthProviders()
if err != nil {
common.ApiError(c, err)
return
}
response := make([]*CustomOAuthProviderResponse, len(providers))
for i, p := range providers {
response[i] = toCustomOAuthProviderResponse(p)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": response,
})
}
// GetCustomOAuthProvider returns a single custom OAuth provider by ID
func GetCustomOAuthProvider(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiErrorMsg(c, "无效的 ID")
return
}
provider, err := model.GetCustomOAuthProviderById(id)
if err != nil {
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": toCustomOAuthProviderResponse(provider),
})
}
// CreateCustomOAuthProviderRequest is the request structure for creating a custom OAuth provider
type CreateCustomOAuthProviderRequest struct {
Name string `json:"name" binding:"required"`
Slug string `json:"slug" binding:"required"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id" binding:"required"`
ClientSecret string `json:"client_secret" binding:"required"`
AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
TokenEndpoint string `json:"token_endpoint" binding:"required"`
UserInfoEndpoint string `json:"user_info_endpoint" binding:"required"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
}
// CreateCustomOAuthProvider creates a new custom OAuth provider
func CreateCustomOAuthProvider(c *gin.Context) {
var req CreateCustomOAuthProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
return
}
// Check if slug is already taken
if model.IsSlugTaken(req.Slug, 0) {
common.ApiErrorMsg(c, "该 Slug 已被使用")
return
}
// Check if slug conflicts with built-in providers
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
return
}
provider := &model.CustomOAuthProvider{
Name: req.Name,
Slug: req.Slug,
Enabled: req.Enabled,
ClientId: req.ClientId,
ClientSecret: req.ClientSecret,
AuthorizationEndpoint: req.AuthorizationEndpoint,
TokenEndpoint: req.TokenEndpoint,
UserInfoEndpoint: req.UserInfoEndpoint,
Scopes: req.Scopes,
UserIdField: req.UserIdField,
UsernameField: req.UsernameField,
DisplayNameField: req.DisplayNameField,
EmailField: req.EmailField,
WellKnown: req.WellKnown,
AuthStyle: req.AuthStyle,
}
if err := model.CreateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
return
}
// Register the provider in the OAuth registry
oauth.RegisterOrUpdateCustomProvider(provider)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "创建成功",
"data": toCustomOAuthProviderResponse(provider),
})
}
// UpdateCustomOAuthProviderRequest is the request structure for updating a custom OAuth provider
type UpdateCustomOAuthProviderRequest struct {
Name string `json:"name"`
Slug string `json:"slug"`
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"` // Optional: if empty, keep existing
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"user_info_endpoint"`
Scopes string `json:"scopes"`
UserIdField string `json:"user_id_field"`
UsernameField string `json:"username_field"`
DisplayNameField string `json:"display_name_field"`
EmailField string `json:"email_field"`
WellKnown string `json:"well_known"`
AuthStyle int `json:"auth_style"`
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
func UpdateCustomOAuthProvider(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiErrorMsg(c, "无效的 ID")
return
}
var req UpdateCustomOAuthProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiErrorMsg(c, "无效的请求参数: "+err.Error())
return
}
// Get existing provider
provider, err := model.GetCustomOAuthProviderById(id)
if err != nil {
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
return
}
oldSlug := provider.Slug
// Check if new slug is taken by another provider
if req.Slug != "" && req.Slug != provider.Slug {
if model.IsSlugTaken(req.Slug, id) {
common.ApiErrorMsg(c, "该 Slug 已被使用")
return
}
// Check if slug conflicts with built-in providers
if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) {
common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突")
return
}
}
// Update fields
if req.Name != "" {
provider.Name = req.Name
}
if req.Slug != "" {
provider.Slug = req.Slug
}
provider.Enabled = req.Enabled
if req.ClientId != "" {
provider.ClientId = req.ClientId
}
if req.ClientSecret != "" {
provider.ClientSecret = req.ClientSecret
}
if req.AuthorizationEndpoint != "" {
provider.AuthorizationEndpoint = req.AuthorizationEndpoint
}
if req.TokenEndpoint != "" {
provider.TokenEndpoint = req.TokenEndpoint
}
if req.UserInfoEndpoint != "" {
provider.UserInfoEndpoint = req.UserInfoEndpoint
}
if req.Scopes != "" {
provider.Scopes = req.Scopes
}
if req.UserIdField != "" {
provider.UserIdField = req.UserIdField
}
if req.UsernameField != "" {
provider.UsernameField = req.UsernameField
}
if req.DisplayNameField != "" {
provider.DisplayNameField = req.DisplayNameField
}
if req.EmailField != "" {
provider.EmailField = req.EmailField
}
provider.WellKnown = req.WellKnown
provider.AuthStyle = req.AuthStyle
if err := model.UpdateCustomOAuthProvider(provider); err != nil {
common.ApiError(c, err)
return
}
// Update the provider in the OAuth registry
if oldSlug != provider.Slug {
oauth.UnregisterCustomProvider(oldSlug)
}
oauth.RegisterOrUpdateCustomProvider(provider)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "更新成功",
"data": toCustomOAuthProviderResponse(provider),
})
}
// DeleteCustomOAuthProvider deletes a custom OAuth provider
func DeleteCustomOAuthProvider(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
common.ApiErrorMsg(c, "无效的 ID")
return
}
// Get existing provider to get slug
provider, err := model.GetCustomOAuthProviderById(id)
if err != nil {
common.ApiErrorMsg(c, "未找到该 OAuth 提供商")
return
}
// Check if there are any user bindings
count, _ := model.GetBindingCountByProviderId(id)
if count > 0 {
common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。")
return
}
if err := model.DeleteCustomOAuthProvider(id); err != nil {
common.ApiError(c, err)
return
}
// Unregister the provider from the OAuth registry
oauth.UnregisterCustomProvider(provider.Slug)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "删除成功",
})
}
// GetUserOAuthBindings returns all OAuth bindings for the current user
func GetUserOAuthBindings(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
common.ApiErrorMsg(c, "未登录")
return
}
bindings, err := model.GetUserOAuthBindingsByUserId(userId)
if err != nil {
common.ApiError(c, err)
return
}
// Build response with provider info
type BindingResponse struct {
ProviderId int `json:"provider_id"`
ProviderName string `json:"provider_name"`
ProviderSlug string `json:"provider_slug"`
ProviderUserId string `json:"provider_user_id"`
}
response := make([]BindingResponse, 0)
for _, binding := range bindings {
provider, err := model.GetCustomOAuthProviderById(binding.ProviderId)
if err != nil {
continue // Skip if provider not found
}
response = append(response, BindingResponse{
ProviderId: binding.ProviderId,
ProviderName: provider.Name,
ProviderSlug: provider.Slug,
ProviderUserId: binding.ProviderUserId,
})
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": response,
})
}
// UnbindCustomOAuth unbinds a custom OAuth provider from the current user
func UnbindCustomOAuth(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
common.ApiErrorMsg(c, "未登录")
return
}
providerIdStr := c.Param("provider_id")
providerId, err := strconv.Atoi(providerIdStr)
if err != nil {
common.ApiErrorMsg(c, "无效的提供商 ID")
return
}
if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "解绑成功",
})
}

223
controller/discord.go Normal file
View File

@@ -0,0 +1,223 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type DiscordResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type DiscordUser struct {
UID string `json:"id"`
ID string `json:"username"`
Name string `json:"global_name"`
}
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res.Body.Close()
var discordResponse DiscordResponse
err = json.NewDecoder(res.Body).Decode(&discordResponse)
if err != nil {
return nil, err
}
if discordResponse.AccessToken == "" {
common.SysError("Discord 获取 Token 失败,请检查设置!")
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysError("Discord 获取用户信息失败!请检查设置!")
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
}
var discordUser DiscordUser
err = json.NewDecoder(res2.Body).Decode(&discordUser)
if err != nil {
return nil, err
}
if discordUser.UID == "" || discordUser.ID == "" {
common.SysError("Discord 获取用户信息为空!请检查设置!")
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
}
return &discordUser, nil
}
func DiscordOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
DiscordBind(c)
return
}
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
err := user.FillUserByDiscordId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
if discordUser.ID != "" {
user.Username = discordUser.ID
} else {
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if discordUser.Name != "" {
user.DisplayName = discordUser.Name
} else {
user.DisplayName = "Discord User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func DiscordBind(c *gin.Context) {
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Discord 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.DiscordId = discordUser.UID
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}

240
controller/github.go Normal file
View File

@@ -0,0 +1,240 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type GitHubOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type GitHubUser struct {
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
}
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse GitHubOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://api.github.com/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
}
defer res2.Body.Close()
var githubUser GitHubUser
err = json.NewDecoder(res2.Body).Decode(&githubUser)
if err != nil {
return nil, err
}
if githubUser.Login == "" {
return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
}
return &githubUser, nil
}
func GitHubOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
GitHubBind(c)
return
}
if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
})
return
}
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
GitHubId: githubUser.Login,
}
// IsGitHubIdAlreadyTaken is unscoped
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
// FillUserByGitHubId is scoped
err := user.FillUserByGitHubId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
// if user.Id == 0 , user has been deleted
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" {
user.DisplayName = githubUser.Name
} else {
user.DisplayName = "GitHub User"
}
user.Email = githubUser.Email
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {
if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 GitHub 登录以及注册",
})
return
}
code := c.Query("code")
githubUser, err := getGitHubUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
GitHubId: githubUser.Login,
}
if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 GitHub 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.GitHubId = githubUser.Login
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
affCode := c.Query("aff")
if affCode != "" {
session.Set("aff", affCode)
}
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}

268
controller/linuxdo.go Normal file
View File

@@ -0,0 +1,268 @@
package controller
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type LinuxdoUser struct {
Id int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func LinuxDoBind(c *gin.Context) {
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Linux DO 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.LinuxDOId = strconv.Itoa(linuxdoUser.Id)
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}
func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error) {
if code == "" {
return nil, errors.New("invalid code")
}
// Get access token using Basic auth
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
// Get redirect URI from request
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", basicAuth)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to connect to Linux DO server")
}
defer res.Body.Close()
var tokenRes struct {
AccessToken string `json:"access_token"`
Message string `json:"message"`
}
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
return nil, err
}
if tokenRes.AccessToken == "" {
return nil, fmt.Errorf("failed to get access token: %s", tokenRes.Message)
}
// Get user info
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
req, err = http.NewRequest("GET", userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+tokenRes.AccessToken)
req.Header.Set("Accept", "application/json")
res2, err := client.Do(req)
if err != nil {
return nil, errors.New("failed to get user info from Linux DO")
}
defer res2.Body.Close()
var linuxdoUser LinuxdoUser
if err := json.NewDecoder(res2.Body).Decode(&linuxdoUser); err != nil {
return nil, err
}
if linuxdoUser.Id == 0 {
return nil, errors.New("invalid user info returned")
}
return &linuxdoUser, nil
}
func LinuxdoOAuth(c *gin.Context) {
session := sessions.Default(c)
errorCode := c.Query("error")
if errorCode != "" {
errorDescription := c.Query("error_description")
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": errorDescription,
})
return
}
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LinuxDoBind(c)
return
}
if !common.LinuxDOOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Linux DO 登录以及注册",
})
return
}
code := c.Query("code")
linuxdoUser, err := getLinuxdoUserInfoByCode(code, c)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
LinuxDOId: strconv.Itoa(linuxdoUser.Id),
}
// Check if user exists
if model.IsLinuxDOIdAlreadyTaken(user.LinuxDOId) {
err := user.FillUserByLinuxDOId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已注销",
})
return
}
} else {
if common.RegisterEnabled {
if linuxdoUser.TrustLevel >= common.LinuxDOMinimumTrustLevel {
user.Username = "linuxdo_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = linuxdoUser.Name
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Linux DO 信任等级未达到管理员设置的最低信任等级",
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}

View File

@@ -20,8 +20,7 @@ func GetAllLogs(c *gin.Context) {
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group")
requestId := c.Query("request_id")
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group, requestId)
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), channel, group)
if err != nil {
common.ApiError(c, err)
return
@@ -41,8 +40,7 @@ func GetUserLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
group := c.Query("group")
requestId := c.Query("request_id")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group, requestId)
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, pageInfo.GetStartIdx(), pageInfo.GetPageSize(), group)
if err != nil {
common.ApiError(c, err)
return
@@ -53,32 +51,40 @@ func GetUserLogs(c *gin.Context) {
return
}
// Deprecated: SearchAllLogs 已废弃,前端未使用该接口。
func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该接口已废弃",
"success": true,
"message": "",
"data": logs,
})
return
}
// Deprecated: SearchUserLogs 已废弃,前端未使用该接口。
func SearchUserLogs(c *gin.Context) {
keyword := c.Query("keyword")
userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该接口已废弃",
"success": true,
"message": "",
"data": logs,
})
return
}
func GetLogByKey(c *gin.Context) {
tokenId := c.GetInt("token_id")
if tokenId == 0 {
c.JSON(200, gin.H{
"success": false,
"message": "无效的令牌",
})
return
}
logs, err := model.GetLogByTokenId(tokenId)
key := c.Query("key")
logs, err := model.GetLogByKey(key)
if err != nil {
c.JSON(200, gin.H{
"success": false,
@@ -102,11 +108,7 @@ func GetLogsStat(c *gin.Context) {
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group")
stat, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
if err != nil {
common.ApiError(c, err)
return
}
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -129,11 +131,7 @@ func GetLogsSelfStat(c *gin.Context) {
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group")
quotaNum, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
if err != nil {
common.ApiError(c, err)
return
}
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(200, gin.H{
"success": true,

View File

@@ -10,7 +10,6 @@ import (
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/QuantumNous/new-api/setting"
"github.com/QuantumNous/new-api/setting/console_setting"
"github.com/QuantumNous/new-api/setting/operation_setting"
@@ -130,30 +129,6 @@ func GetStatus(c *gin.Context) {
data["faq"] = console_setting.GetFAQ()
}
// Add enabled custom OAuth providers
customProviders := oauth.GetEnabledCustomProviders()
if len(customProviders) > 0 {
type CustomOAuthInfo struct {
Name string `json:"name"`
Slug string `json:"slug"`
ClientId string `json:"client_id"`
AuthorizationEndpoint string `json:"authorization_endpoint"`
Scopes string `json:"scopes"`
}
providersInfo := make([]CustomOAuthInfo, 0, len(customProviders))
for _, p := range customProviders {
config := p.GetConfig()
providersInfo = append(providersInfo, CustomOAuthInfo{
Name: config.Name,
Slug: config.Slug,
ClientId: config.ClientId,
AuthorizationEndpoint: config.AuthorizationEndpoint,
Scopes: config.Scopes,
})
}
data["custom_oauth_providers"] = providersInfo
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

View File

@@ -272,8 +272,7 @@ func SyncUpstreamModels(c *gin.Context) {
// 1) 获取未配置模型列表
missing, err := model.GetMissingModels()
if err != nil {
common.SysError("failed to get missing models: " + err.Error())
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取模型列表失败,请稍后重试"})
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}

View File

@@ -1,312 +0,0 @@
package controller
import (
"fmt"
"net/http"
"strconv"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// providerParams returns map with Provider key for i18n templates
func providerParams(name string) map[string]any {
return map[string]any{"Provider": name}
}
// GenerateOAuthCode generates a state code for OAuth CSRF protection
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := common.GetRandomString(12)
affCode := c.Query("aff")
if affCode != "" {
session.Set("aff", affCode)
}
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}
// HandleOAuth handles OAuth callback for all standard OAuth providers
func HandleOAuth(c *gin.Context) {
providerName := c.Param("provider")
provider := oauth.GetProvider(providerName)
if provider == nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": i18n.T(c, i18n.MsgOAuthUnknownProvider),
})
return
}
session := sessions.Default(c)
// 1. Validate state (CSRF protection)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": i18n.T(c, i18n.MsgOAuthStateInvalid),
})
return
}
// 2. Check if user is already logged in (bind flow)
username := session.Get("username")
if username != nil {
handleOAuthBind(c, provider)
return
}
// 3. Check if provider is enabled
if !provider.IsEnabled() {
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
return
}
// 4. Handle error from provider
errorCode := c.Query("error")
if errorCode != "" {
errorDescription := c.Query("error_description")
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": errorDescription,
})
return
}
// 5. Exchange code for token
code := c.Query("code")
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
if err != nil {
handleOAuthError(c, err)
return
}
// 6. Get user info
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
if err != nil {
handleOAuthError(c, err)
return
}
// 7. Find or create user
user, err := findOrCreateOAuthUser(c, provider, oauthUser, session)
if err != nil {
switch err.(type) {
case *OAuthUserDeletedError:
common.ApiErrorI18n(c, i18n.MsgOAuthUserDeleted)
case *OAuthRegistrationDisabledError:
common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
default:
common.ApiError(c, err)
}
return
}
// 8. Check user status
if user.Status != common.UserStatusEnabled {
common.ApiErrorI18n(c, i18n.MsgOAuthUserBanned)
return
}
// 9. Setup login
setupLogin(user, c)
}
// handleOAuthBind handles binding OAuth account to existing user
func handleOAuthBind(c *gin.Context, provider oauth.Provider) {
if !provider.IsEnabled() {
common.ApiErrorI18n(c, i18n.MsgOAuthNotEnabled, providerParams(provider.GetName()))
return
}
// Exchange code for token
code := c.Query("code")
token, err := provider.ExchangeToken(c.Request.Context(), code, c)
if err != nil {
handleOAuthError(c, err)
return
}
// Get user info
oauthUser, err := provider.GetUserInfo(c.Request.Context(), token)
if err != nil {
handleOAuthError(c, err)
return
}
// Check if this OAuth account is already bound (check both new ID and legacy ID)
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
return
}
// Also check legacy ID to prevent duplicate bindings during migration period
if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
if provider.IsUserIDTaken(legacyID) {
common.ApiErrorI18n(c, i18n.MsgOAuthAlreadyBound, providerParams(provider.GetName()))
return
}
}
// Get current user from session
session := sessions.Default(c)
id := session.Get("id")
user := model.User{Id: id.(int)}
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
// Handle binding based on provider type
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
// Custom provider: use user_oauth_bindings table
err = model.UpdateUserOAuthBinding(user.Id, genericProvider.GetProviderId(), oauthUser.ProviderUserID)
if err != nil {
common.ApiError(c, err)
return
}
} else {
// Built-in provider: update user record directly
provider.SetProviderUserID(&user, oauthUser.ProviderUserID)
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
}
common.ApiSuccessI18n(c, i18n.MsgOAuthBindSuccess, nil)
}
// findOrCreateOAuthUser finds existing user or creates new user
func findOrCreateOAuthUser(c *gin.Context, provider oauth.Provider, oauthUser *oauth.OAuthUser, session sessions.Session) (*model.User, error) {
user := &model.User{}
// Check if user already exists with new ID
if provider.IsUserIDTaken(oauthUser.ProviderUserID) {
err := provider.FillUserByProviderID(user, oauthUser.ProviderUserID)
if err != nil {
return nil, err
}
// Check if user has been deleted
if user.Id == 0 {
return nil, &OAuthUserDeletedError{}
}
return user, nil
}
// Try to find user with legacy ID (for GitHub migration from login to numeric ID)
if legacyID, ok := oauthUser.Extra["legacy_id"].(string); ok && legacyID != "" {
if provider.IsUserIDTaken(legacyID) {
err := provider.FillUserByProviderID(user, legacyID)
if err != nil {
return nil, err
}
if user.Id != 0 {
// Found user with legacy ID, migrate to new ID
common.SysLog(fmt.Sprintf("[OAuth] Migrating user %d from legacy_id=%s to new_id=%s",
user.Id, legacyID, oauthUser.ProviderUserID))
if err := user.UpdateGitHubId(oauthUser.ProviderUserID); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to migrate user %d: %s", user.Id, err.Error()))
// Continue with login even if migration fails
}
return user, nil
}
}
}
// User doesn't exist, create new user if registration is enabled
if !common.RegisterEnabled {
return nil, &OAuthRegistrationDisabledError{}
}
// Set up new user
user.Username = provider.GetProviderPrefix() + strconv.Itoa(model.GetMaxUserId()+1)
if oauthUser.DisplayName != "" {
user.DisplayName = oauthUser.DisplayName
} else if oauthUser.Username != "" {
user.DisplayName = oauthUser.Username
} else {
user.DisplayName = provider.GetName() + " User"
}
if oauthUser.Email != "" {
user.Email = oauthUser.Email
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
// Handle affiliate code
affCode := session.Get("aff")
inviterId := 0
if affCode != nil {
inviterId, _ = model.GetUserIdByAffCode(affCode.(string))
}
if err := user.Insert(inviterId); err != nil {
return nil, err
}
// For custom providers, create the binding after user is created
if genericProvider, ok := provider.(*oauth.GenericOAuthProvider); ok {
binding := &model.UserOAuthBinding{
UserId: user.Id,
ProviderId: genericProvider.GetProviderId(),
ProviderUserId: oauthUser.ProviderUserID,
}
if err := model.CreateUserOAuthBinding(binding); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to create binding for user %d: %s", user.Id, err.Error()))
// Don't fail the registration, just log the error
}
} else {
// Built-in provider: set the provider user ID on the user model
provider.SetProviderUserID(user, oauthUser.ProviderUserID)
if err := user.Update(false); err != nil {
common.SysError(fmt.Sprintf("[OAuth] Failed to update provider ID for user %d: %s", user.Id, err.Error()))
}
}
return user, nil
}
// Error types for OAuth
type OAuthUserDeletedError struct{}
func (e *OAuthUserDeletedError) Error() string {
return "user has been deleted"
}
type OAuthRegistrationDisabledError struct{}
func (e *OAuthRegistrationDisabledError) Error() string {
return "registration is disabled"
}
// handleOAuthError handles OAuth errors and returns translated message
func handleOAuthError(c *gin.Context, err error) {
switch e := err.(type) {
case *oauth.OAuthError:
if e.Params != nil {
common.ApiErrorI18n(c, e.MsgKey, e.Params)
} else {
common.ApiErrorI18n(c, e.MsgKey)
}
case *oauth.TrustLevelError:
common.ApiErrorI18n(c, i18n.MsgOAuthTrustLevelLow)
default:
common.ApiError(c, err)
}
}

228
controller/oidc.go Normal file
View File

@@ -0,0 +1,228 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetOIDCSettings().ClientId)
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
if oidcResponse.AccessToken == "" {
common.SysLog("OIDC 获取 Token 失败,请检查设置!")
return nil, errors.New("OIDC 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", system_setting.GetOIDCSettings().UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysLog("OIDC 获取用户信息失败!请检查设置!")
return nil, errors.New("OIDC 获取用户信息失败!请检查设置!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
common.SysLog("OIDC 获取用户信息为空!请检查设置!")
return nil, errors.New("OIDC 获取用户信息为空!请检查设置!")
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !system_setting.GetOIDCSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -3,8 +3,8 @@ package controller
import (
"net/http"
"os"
"path/filepath"
"runtime"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/gin-gonic/gin"
@@ -19,7 +19,7 @@ type PerformanceStats struct {
// 磁盘缓存目录信息
DiskCacheInfo DiskCacheInfo `json:"disk_cache_info"`
// 磁盘空间信息
DiskSpaceInfo common.DiskSpaceInfo `json:"disk_space_info"`
DiskSpaceInfo DiskSpaceInfo `json:"disk_space_info"`
// 配置信息
Config PerformanceConfig `json:"config"`
}
@@ -50,6 +50,18 @@ type DiskCacheInfo struct {
TotalSize int64 `json:"total_size"`
}
// DiskSpaceInfo 磁盘空间信息
type DiskSpaceInfo struct {
// 总空间(字节)
Total uint64 `json:"total"`
// 可用空间(字节)
Free uint64 `json:"free"`
// 已用空间(字节)
Used uint64 `json:"used"`
// 使用百分比
UsedPercent float64 `json:"used_percent"`
}
// PerformanceConfig 性能配置
type PerformanceConfig struct {
// 是否启用磁盘缓存
@@ -62,21 +74,11 @@ type PerformanceConfig struct {
DiskCachePath string `json:"disk_cache_path"`
// 是否在容器中运行
IsRunningInContainer bool `json:"is_running_in_container"`
// MonitorEnabled 是否启用性能监控
MonitorEnabled bool `json:"monitor_enabled"`
// MonitorCPUThreshold CPU 使用率阈值(%
MonitorCPUThreshold int `json:"monitor_cpu_threshold"`
// MonitorMemoryThreshold 内存使用率阈值(%
MonitorMemoryThreshold int `json:"monitor_memory_threshold"`
// MonitorDiskThreshold 磁盘使用率阈值(%
MonitorDiskThreshold int `json:"monitor_disk_threshold"`
}
// GetPerformanceStats 获取性能统计信息
func GetPerformanceStats(c *gin.Context) {
// 不再每次获取统计都全量扫描磁盘,依赖原子计数器保证性能
// 仅在系统启动或显式清理时同步
// 获取缓存统计
cacheStats := common.GetDiskCacheStats()
// 获取内存统计
@@ -88,30 +90,16 @@ func GetPerformanceStats(c *gin.Context) {
// 获取配置信息
diskConfig := common.GetDiskCacheConfig()
monitorConfig := common.GetPerformanceMonitorConfig()
config := PerformanceConfig{
DiskCacheEnabled: diskConfig.Enabled,
DiskCacheThresholdMB: diskConfig.ThresholdMB,
DiskCacheMaxSizeMB: diskConfig.MaxSizeMB,
DiskCachePath: diskConfig.Path,
IsRunningInContainer: common.IsRunningInContainer(),
MonitorEnabled: monitorConfig.Enabled,
MonitorCPUThreshold: monitorConfig.CPUThreshold,
MonitorMemoryThreshold: monitorConfig.MemoryThreshold,
MonitorDiskThreshold: monitorConfig.DiskThreshold,
DiskCacheEnabled: diskConfig.Enabled,
DiskCacheThresholdMB: diskConfig.ThresholdMB,
DiskCacheMaxSizeMB: diskConfig.MaxSizeMB,
DiskCachePath: diskConfig.Path,
IsRunningInContainer: common.IsRunningInContainer(),
}
// 获取磁盘空间信息
// 使用缓存的系统状态,避免频繁调用系统 API
systemStatus := common.GetSystemStatus()
diskSpaceInfo := common.DiskSpaceInfo{
UsedPercent: systemStatus.DiskUsage,
}
// 如果需要详细信息,可以按需获取,或者扩展 SystemStatus
// 这里为了保持接口兼容性,我们仍然调用 GetDiskSpaceInfo但注意这可能会有性能开销
// 考虑到 GetPerformanceStats 是管理接口,频率较低,直接调用是可以接受的
// 但为了一致性,我们也可以考虑从 SystemStatus 中获取部分信息
diskSpaceInfo = common.GetDiskSpaceInfo()
diskSpaceInfo := getDiskSpaceInfo()
stats := PerformanceStats{
CacheStats: cacheStats,
@@ -133,19 +121,27 @@ func GetPerformanceStats(c *gin.Context) {
})
}
// ClearDiskCache 清理不活跃的磁盘缓存
// ClearDiskCache 清理磁盘缓存
func ClearDiskCache(c *gin.Context) {
// 清理超过 10 分钟未使用的缓存文件
// 10 分钟是一个安全的阈值,确保正在进行的请求不会被误删
err := common.CleanupOldDiskCacheFiles(10 * time.Minute)
if err != nil {
cachePath := common.GetDiskCachePath()
if cachePath == "" {
cachePath = os.TempDir()
}
dir := filepath.Join(cachePath, "new-api-body-cache")
// 删除缓存目录
err := os.RemoveAll(dir)
if err != nil && !os.IsNotExist(err) {
common.ApiError(c, err)
return
}
// 重置统计
common.ResetDiskCacheStats()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "不活跃的磁盘缓存已清理",
"message": "磁盘缓存已清理",
})
}
@@ -171,8 +167,11 @@ func ForceGC(c *gin.Context) {
// getDiskCacheInfo 获取磁盘缓存目录信息
func getDiskCacheInfo() DiskCacheInfo {
// 使用统一的缓存目录
dir := common.GetDiskCacheDir()
cachePath := common.GetDiskCachePath()
if cachePath == "" {
cachePath = os.TempDir()
}
dir := filepath.Join(cachePath, "new-api-body-cache")
info := DiskCacheInfo{
Path: dir,

View File

@@ -1,16 +1,17 @@
//go:build !windows
package common
package controller
import (
"os"
"github.com/QuantumNous/new-api/common"
"golang.org/x/sys/unix"
)
// GetDiskSpaceInfo 获取缓存目录所在磁盘的空间信息 (Unix/Linux/macOS)
func GetDiskSpaceInfo() DiskSpaceInfo {
cachePath := GetDiskCachePath()
// getDiskSpaceInfo 获取缓存目录所在磁盘的空间信息 (Unix/Linux/macOS)
func getDiskSpaceInfo() DiskSpaceInfo {
cachePath := common.GetDiskCachePath()
if cachePath == "" {
cachePath = os.TempDir()
}

View File

@@ -1,16 +1,18 @@
//go:build windows
package common
package controller
import (
"os"
"syscall"
"unsafe"
"github.com/QuantumNous/new-api/common"
)
// GetDiskSpaceInfo 获取缓存目录所在磁盘的空间信息 (Windows)
func GetDiskSpaceInfo() DiskSpaceInfo {
cachePath := GetDiskCachePath()
// getDiskSpaceInfo 获取缓存目录所在磁盘的空间信息 (Windows)
func getDiskSpaceInfo() DiskSpaceInfo {
cachePath := common.GetDiskCachePath()
if cachePath == "" {
cachePath = os.TempDir()
}

View File

@@ -56,8 +56,7 @@ type upstreamResult struct {
func FetchUpstreamRatios(c *gin.Context) {
var req dto.UpstreamRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.SysError("failed to bind upstream request: " + err.Error())
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": "请求参数格式错误"})
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
return
}

View File

@@ -1,12 +1,12 @@
package controller
import (
"errors"
"net/http"
"strconv"
"unicode/utf8"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
@@ -66,19 +66,28 @@ func AddRedemption(c *gin.Context) {
return
}
if utf8.RuneCountInString(redemption.Name) == 0 || utf8.RuneCountInString(redemption.Name) > 20 {
common.ApiErrorI18n(c, i18n.MsgRedemptionNameLength)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "兑换码名称长度必须在1-20之间",
})
return
}
if redemption.Count <= 0 {
common.ApiErrorI18n(c, i18n.MsgRedemptionCountPositive)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "兑换码个数必须大于0",
})
return
}
if redemption.Count > 100 {
common.ApiErrorI18n(c, i18n.MsgRedemptionCountMax)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "一次兑换码批量生成的个数不能大于 100",
})
return
}
if valid, msg := validateExpiredTime(c, redemption.ExpiredTime); !valid {
c.JSON(http.StatusOK, gin.H{"success": false, "message": msg})
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
var keys []string
@@ -94,10 +103,9 @@ func AddRedemption(c *gin.Context) {
}
err = cleanRedemption.Insert()
if err != nil {
common.SysError("failed to insert redemption: " + err.Error())
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": i18n.T(c, i18n.MsgRedemptionCreateFailed),
"message": err.Error(),
"data": keys,
})
return
@@ -140,8 +148,8 @@ func UpdateRedemption(c *gin.Context) {
return
}
if statusOnly == "" {
if valid, msg := validateExpiredTime(c, redemption.ExpiredTime); !valid {
c.JSON(http.StatusOK, gin.H{"success": false, "message": msg})
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
return
}
// If you add more fields, please also update redemption.Update()
@@ -179,9 +187,9 @@ func DeleteInvalidRedemption(c *gin.Context) {
return
}
func validateExpiredTime(c *gin.Context, expired int64) (bool, string) {
func validateExpiredTime(expired int64) error {
if expired != 0 && expired < common.GetTimestamp() {
return false, i18n.T(c, i18n.MsgRedemptionExpireTimeInvalid)
return errors.New("过期时间不能早于当前时间")
}
return true, ""
return nil
}

View File

@@ -8,7 +8,6 @@ import (
"log"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
@@ -170,8 +169,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
// Only return quota if downstream failed and quota was actually pre-consumed
if newAPIError != nil {
newAPIError = service.NormalizeViolationFeeError(newAPIError)
if relayInfo.Billing != nil {
relayInfo.Billing.Refund(c)
if relayInfo.FinalPreConsumedQuota != 0 {
service.ReturnPreConsumedQuota(c, relayInfo)
}
service.ChargeViolationFeeIfNeeded(c, relayInfo, newAPIError)
}
@@ -374,12 +373,7 @@ func processChannelError(c *gin.Context, channelError types.ChannelError, err *t
}
service.AppendChannelAffinityAdminInfo(c, adminInfo)
other["admin_info"] = adminInfo
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
if startTime.IsZero() {
startTime = time.Now()
}
useTimeSeconds := int(time.Since(startTime).Seconds())
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, useTimeSeconds, false, userGroup, other)
model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveErrorWithStatusCode(), tokenId, 0, false, userGroup, other)
}
}

View File

@@ -133,6 +133,94 @@ func UniversalVerify(c *gin.Context) {
})
}
// GetVerificationStatus 获取验证状态
func GetVerificationStatus(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "未登录",
})
return
}
session := sessions.Default(c)
verifiedAtRaw := session.Get(SecureVerificationSessionKey)
if verifiedAtRaw == nil {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: false,
},
})
return
}
verifiedAt, ok := verifiedAtRaw.(int64)
if !ok {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: false,
},
})
return
}
elapsed := time.Now().Unix() - verifiedAt
if elapsed >= SecureVerificationTimeout {
// 验证已过期
session.Delete(SecureVerificationSessionKey)
_ = session.Save()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: false,
},
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: true,
ExpiresAt: verifiedAt + SecureVerificationTimeout,
},
})
}
// CheckSecureVerification 检查是否已通过安全验证
// 返回 true 表示验证有效false 表示需要重新验证
func CheckSecureVerification(c *gin.Context) bool {
session := sessions.Default(c)
verifiedAtRaw := session.Get(SecureVerificationSessionKey)
if verifiedAtRaw == nil {
return false
}
verifiedAt, ok := verifiedAtRaw.(int64)
if !ok {
return false
}
elapsed := time.Now().Unix() - verifiedAt
if elapsed >= SecureVerificationTimeout {
// 验证已过期,清除 session
session.Delete(SecureVerificationSessionKey)
_ = session.Save()
return false
}
return true
}
// PasskeyVerifyAndSetSession Passkey 验证完成后设置 session
// 这是一个辅助函数,供 PasskeyVerifyFinish 调用
func PasskeyVerifyAndSetSession(c *gin.Context) {

View File

@@ -118,14 +118,6 @@ func AdminCreateSubscriptionPlan(c *gin.Context) {
common.ApiErrorMsg(c, "套餐标题不能为空")
return
}
if req.Plan.PriceAmount < 0 {
common.ApiErrorMsg(c, "价格不能为负数")
return
}
if req.Plan.PriceAmount > 9999 {
common.ApiErrorMsg(c, "价格不能超过9999")
return
}
if req.Plan.Currency == "" {
req.Plan.Currency = "USD"
}
@@ -180,14 +172,6 @@ func AdminUpdateSubscriptionPlan(c *gin.Context) {
common.ApiErrorMsg(c, "套餐标题不能为空")
return
}
if req.Plan.PriceAmount < 0 {
common.ApiErrorMsg(c, "价格不能为负数")
return
}
if req.Plan.PriceAmount > 9999 {
common.ApiErrorMsg(c, "价格不能超过9999")
return
}
req.Plan.Id = id
if req.Plan.Currency == "" {
req.Plan.Currency = "USD"

View File

@@ -108,35 +108,25 @@ func SubscriptionRequestEpay(c *gin.Context) {
common.ApiErrorMsg(c, "拉起支付失败")
return
}
c.JSON(http.StatusOK, gin.H{"message": "success", "data": params, "url": uri})
common.ApiSuccess(c, gin.H{"data": params, "url": uri})
}
func SubscriptionEpayNotify(c *gin.Context) {
var params map[string]string
if c.Request.Method == "POST" {
// POST 请求:从 POST body 解析参数
if err := c.Request.ParseForm(); err != nil {
_, _ = c.Writer.Write([]byte("fail"))
return
}
params = lo.Reduce(lo.Keys(c.Request.PostForm), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.PostForm.Get(t)
return r
}, map[string]string{})
} else {
// GET 请求:从 URL Query 解析参数
if err := c.Request.ParseForm(); err != nil {
_, _ = c.Writer.Write([]byte("fail"))
return
}
params := lo.Reduce(lo.Keys(c.Request.PostForm), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.PostForm.Get(t)
return r
}, map[string]string{})
if len(params) == 0 {
params = lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
}
if len(params) == 0 {
_, _ = c.Writer.Write([]byte("fail"))
return
}
client := GetEpayClient()
if client == nil {
_, _ = c.Writer.Write([]byte("fail"))
@@ -167,31 +157,21 @@ func SubscriptionEpayNotify(c *gin.Context) {
// SubscriptionEpayReturn handles browser return after payment.
// It verifies the payload and completes the order, then redirects to console.
func SubscriptionEpayReturn(c *gin.Context) {
var params map[string]string
if c.Request.Method == "POST" {
// POST 请求:从 POST body 解析参数
if err := c.Request.ParseForm(); err != nil {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/subscription?pay=fail")
return
}
params = lo.Reduce(lo.Keys(c.Request.PostForm), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.PostForm.Get(t)
return r
}, map[string]string{})
} else {
// GET 请求:从 URL Query 解析参数
if err := c.Request.ParseForm(); err != nil {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/subscription?pay=fail")
return
}
params := lo.Reduce(lo.Keys(c.Request.PostForm), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.PostForm.Get(t)
return r
}, map[string]string{})
if len(params) == 0 {
params = lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
}
if len(params) == 0 {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/subscription?pay=fail")
return
}
client := GetEpayClient()
if client == nil {
c.Redirect(http.StatusFound, system_setting.ServerAddress+"/console/subscription?pay=fail")

View File

@@ -7,9 +7,7 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
)
@@ -33,17 +31,16 @@ func SearchTokens(c *gin.Context) {
userId := c.GetInt("id")
keyword := c.Query("keyword")
token := c.Query("token")
pageInfo := common.GetPageQuery(c)
tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
tokens, err := model.SearchUserTokens(userId, keyword, token)
if err != nil {
common.ApiError(c, err)
return
}
pageInfo.SetTotal(int(total))
pageInfo.SetItems(tokens)
common.ApiSuccess(c, pageInfo)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": tokens,
})
return
}
@@ -110,8 +107,10 @@ func GetTokenUsage(c *gin.Context) {
token, err := model.GetTokenByKey(strings.TrimPrefix(tokenKey, "sk-"), false)
if err != nil {
common.SysError("failed to get token by key: " + err.Error())
common.ApiErrorI18n(c, i18n.MsgTokenGetInfoFailed)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
@@ -145,38 +144,36 @@ func AddToken(c *gin.Context) {
return
}
if len(token.Name) > 50 {
common.ApiErrorI18n(c, i18n.MsgTokenNameTooLong)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
})
return
}
// 非无限额度时,检查额度值是否超出有效范围
if !token.UnlimitedQuota {
if token.RemainQuota < 0 {
common.ApiErrorI18n(c, i18n.MsgTokenQuotaNegative)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "额度值不能为负数",
})
return
}
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
if token.RemainQuota > maxQuotaValue {
common.ApiErrorI18n(c, i18n.MsgTokenQuotaExceedMax, map[string]any{"Max": maxQuotaValue})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
})
return
}
}
// 检查用户令牌数量是否已达上限
maxTokens := operation_setting.GetMaxUserTokens()
count, err := model.CountUserTokens(c.GetInt("id"))
if err != nil {
common.ApiError(c, err)
return
}
if int(count) >= maxTokens {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
})
return
}
key, err := common.GenerateKey()
if err != nil {
common.ApiErrorI18n(c, i18n.MsgTokenGenerateFailed)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成令牌失败",
})
common.SysLog("failed to generate token key: " + err.Error())
return
}
@@ -232,17 +229,26 @@ func UpdateToken(c *gin.Context) {
return
}
if len(token.Name) > 50 {
common.ApiErrorI18n(c, i18n.MsgTokenNameTooLong)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
})
return
}
if !token.UnlimitedQuota {
if token.RemainQuota < 0 {
common.ApiErrorI18n(c, i18n.MsgTokenQuotaNegative)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "额度值不能为负数",
})
return
}
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
if token.RemainQuota > maxQuotaValue {
common.ApiErrorI18n(c, i18n.MsgTokenQuotaExceedMax, map[string]any{"Max": maxQuotaValue})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
})
return
}
}
@@ -253,11 +259,17 @@ func UpdateToken(c *gin.Context) {
}
if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
common.ApiErrorI18n(c, i18n.MsgTokenExpiredCannotEnable)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
})
return
}
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
common.ApiErrorI18n(c, i18n.MsgTokenExhaustedCannotEable)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
})
return
}
}
@@ -294,7 +306,10 @@ type TokenBatch struct {
func DeleteTokenBatch(c *gin.Context) {
tokenBatch := TokenBatch{}
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
userId := c.GetInt("id")

View File

@@ -228,32 +228,21 @@ func UnlockOrder(tradeNo string) {
}
func EpayNotify(c *gin.Context) {
var params map[string]string
if c.Request.Method == "POST" {
// POST 请求:从 POST body 解析参数
if err := c.Request.ParseForm(); err != nil {
log.Println("易支付回调POST解析失败:", err)
_, _ = c.Writer.Write([]byte("fail"))
return
}
params = lo.Reduce(lo.Keys(c.Request.PostForm), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.PostForm.Get(t)
return r
}, map[string]string{})
} else {
// GET 请求:从 URL Query 解析参数
if err := c.Request.ParseForm(); err != nil {
log.Println("易支付回调解析失败:", err)
_, _ = c.Writer.Write([]byte("fail"))
return
}
params := lo.Reduce(lo.Keys(c.Request.PostForm), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.PostForm.Get(t)
return r
}, map[string]string{})
if len(params) == 0 {
params = lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
}
if len(params) == 0 {
log.Println("易支付回调参数为空")
_, _ = c.Writer.Write([]byte("fail"))
return
}
client := GetEpayClient()
if client == nil {
log.Println("易支付回调失败 未找到配置信息")

View File

@@ -2,7 +2,6 @@ package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@@ -12,7 +11,6 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/service"
@@ -31,19 +29,28 @@ type LoginRequest struct {
func Login(c *gin.Context) {
if !common.PasswordLoginEnabled {
common.ApiErrorI18n(c, i18n.MsgUserPasswordLoginDisabled)
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录",
"success": false,
})
return
}
var loginRequest LoginRequest
err := json.NewDecoder(c.Request.Body).Decode(&loginRequest)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"message": "无效的参数",
"success": false,
})
return
}
username := loginRequest.Username
password := loginRequest.Password
if username == "" || password == "" {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"message": "无效的参数",
"success": false,
})
return
}
user := model.User{
@@ -67,12 +74,15 @@ func Login(c *gin.Context) {
session.Set("pending_user_id", user.Id)
err := session.Save()
if err != nil {
common.ApiErrorI18n(c, i18n.MsgUserSessionSaveFailed)
c.JSON(http.StatusOK, gin.H{
"message": "无法保存会话信息,请重试",
"success": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
"message": i18n.T(c, i18n.MsgUserRequire2FA),
"message": "请输入两步验证码",
"success": true,
"data": map[string]interface{}{
"require_2fa": true,
@@ -94,7 +104,10 @@ func setupLogin(user *model.User, c *gin.Context) {
session.Set("group", user.Group)
err := session.Save()
if err != nil {
common.ApiErrorI18n(c, i18n.MsgUserSessionSaveFailed)
c.JSON(http.StatusOK, gin.H{
"message": "无法保存会话信息,请重试",
"success": false,
})
return
}
c.JSON(http.StatusOK, gin.H{
@@ -130,41 +143,65 @@ func Logout(c *gin.Context) {
func Register(c *gin.Context) {
if !common.RegisterEnabled {
common.ApiErrorI18n(c, i18n.MsgUserRegisterDisabled)
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册",
"success": false,
})
return
}
if !common.PasswordRegisterEnabled {
common.ApiErrorI18n(c, i18n.MsgUserPasswordRegisterDisabled)
c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false,
})
return
}
var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
if err := common.Validate.Struct(&user); err != nil {
common.ApiErrorI18n(c, i18n.MsgUserInputInvalid, map[string]any{"Error": err.Error()})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "输入不合法 " + err.Error(),
})
return
}
if common.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" {
common.ApiErrorI18n(c, i18n.MsgUserEmailVerificationRequired)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员开启了邮箱验证,请输入邮箱地址和验证码",
})
return
}
if !common.VerifyCodeWithKey(user.Email, user.VerificationCode, common.EmailVerificationPurpose) {
common.ApiErrorI18n(c, i18n.MsgUserVerificationCodeError)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码错误或已过期",
})
return
}
}
exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgDatabaseError)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "数据库错误,请稍后重试",
})
common.SysLog(fmt.Sprintf("CheckUserExistOrDeleted error: %v", err))
return
}
if exist {
common.ApiErrorI18n(c, i18n.MsgUserExists)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户名已存在,或已注销",
})
return
}
affCode := user.AffCode // this code is the inviter's code, not the user's own code
@@ -187,14 +224,20 @@ func Register(c *gin.Context) {
// 获取插入后的用户ID
var insertedUser model.User
if err := model.DB.Where("username = ?", cleanUser.Username).First(&insertedUser).Error; err != nil {
common.ApiErrorI18n(c, i18n.MsgUserRegisterFailed)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户注册失败或用户ID获取失败",
})
return
}
// 生成默认令牌
if constant.GenerateDefaultToken {
key, err := common.GenerateKey()
if err != nil {
common.ApiErrorI18n(c, i18n.MsgUserDefaultTokenFailed)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成默认令牌失败",
})
common.SysLog("failed to generate token key: " + err.Error())
return
}
@@ -214,7 +257,10 @@ func Register(c *gin.Context) {
token.Group = "auto"
}
if err := token.Insert(); err != nil {
common.ApiErrorI18n(c, i18n.MsgCreateDefaultTokenErr)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "创建默认令牌失败",
})
return
}
}
@@ -270,7 +316,10 @@ func GetUser(c *gin.Context) {
}
myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionSameLevel)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权获取同级或更高等级用户的信息",
})
return
}
c.JSON(http.StatusOK, gin.H{
@@ -292,14 +341,20 @@ func GenerateAccessToken(c *gin.Context) {
randI := common.GetRandomInt(4)
key, err := common.GenerateRandomKey(29 + randI)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgGenerateFailed)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "生成失败",
})
common.SysLog("failed to generate key: " + err.Error())
return
}
user.SetAccessToken(key)
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
common.ApiErrorI18n(c, i18n.MsgUuidDuplicate)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "请重试,系统生成的 UUID 竟然重复了!",
})
return
}
@@ -334,10 +389,16 @@ func TransferAffQuota(c *gin.Context) {
}
err = user.TransferAffQuotaToQuota(tran.Quota)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgUserTransferFailed, map[string]any{"Error": err.Error()})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "划转失败 " + err.Error(),
})
return
}
common.ApiSuccessI18n(c, i18n.MsgUserTransferSuccess, nil)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "划转成功",
})
}
func GetAffCode(c *gin.Context) {
@@ -540,14 +601,20 @@ func UpdateUser(c *gin.Context) {
var updatedUser model.User
err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
if err != nil || updatedUser.Id == 0 {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
if updatedUser.Password == "" {
updatedUser.Password = "$I_LOVE_U" // make Validator happy :)
}
if err := common.Validate.Struct(&updatedUser); err != nil {
common.ApiErrorI18n(c, i18n.MsgUserInputInvalid, map[string]any{"Error": err.Error()})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "输入不合法 " + err.Error(),
})
return
}
originUser, err := model.GetUserById(updatedUser.Id, false)
@@ -557,11 +624,17 @@ func UpdateUser(c *gin.Context) {
}
myRole := c.GetInt("role")
if myRole <= originUser.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionHigherLevel)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
})
return
}
if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserCannotCreateHigherLevel)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
})
return
}
if updatedUser.Password == "$I_LOVE_U" {
@@ -586,12 +659,15 @@ func UpdateSelf(c *gin.Context) {
var requestData map[string]interface{}
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
// 检查是否是用户设置更新请求 (sidebar_modules 或 language)
if sidebarModules, sidebarExists := requestData["sidebar_modules"]; sidebarExists {
// 检查是否是sidebar_modules更新请求
if sidebarModules, exists := requestData["sidebar_modules"]; exists {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, false)
if err != nil {
@@ -610,39 +686,17 @@ func UpdateSelf(c *gin.Context) {
// 保存更新后的设置
user.SetSetting(currentSetting)
if err := user.Update(false); err != nil {
common.ApiErrorI18n(c, i18n.MsgUpdateFailed)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "更新设置失败: " + err.Error(),
})
return
}
common.ApiSuccessI18n(c, i18n.MsgUpdateSuccess, nil)
return
}
// 检查是否是语言偏好更新请求
if language, langExists := requestData["language"]; langExists {
userId := c.GetInt("id")
user, err := model.GetUserById(userId, false)
if err != nil {
common.ApiError(c, err)
return
}
// 获取当前用户设置
currentSetting := user.GetSetting()
// 更新language字段
if langStr, ok := language.(string); ok {
currentSetting.Language = langStr
}
// 保存更新后的设置
user.SetSetting(currentSetting)
if err := user.Update(false); err != nil {
common.ApiErrorI18n(c, i18n.MsgUpdateFailed)
return
}
common.ApiSuccessI18n(c, i18n.MsgUpdateSuccess, nil)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "设置更新成功",
})
return
}
@@ -650,12 +704,18 @@ func UpdateSelf(c *gin.Context) {
var user model.User
requestDataBytes, err := json.Marshal(requestData)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
err = json.Unmarshal(requestDataBytes, &user)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
@@ -663,7 +723,10 @@ func UpdateSelf(c *gin.Context) {
user.Password = "$I_LOVE_U" // make Validator happy :)
}
if err := common.Validate.Struct(&user); err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidInput)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "输入不合法 " + err.Error(),
})
return
}
@@ -727,7 +790,10 @@ func DeleteUser(c *gin.Context) {
}
myRole := c.GetInt("role")
if myRole <= originUser.Role {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionHigherLevel)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权删除同权限等级或更高权限等级的用户",
})
return
}
err = model.HardDeleteUserById(id)
@@ -745,7 +811,10 @@ func DeleteSelf(c *gin.Context) {
user, _ := model.GetUserById(id, false)
if user.Role == common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserCannotDeleteRootUser)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除超级管理员账户",
})
return
}
@@ -766,11 +835,17 @@ func CreateUser(c *gin.Context) {
err := json.NewDecoder(c.Request.Body).Decode(&user)
user.Username = strings.TrimSpace(user.Username)
if err != nil || user.Username == "" || user.Password == "" {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
if err := common.Validate.Struct(&user); err != nil {
common.ApiErrorI18n(c, i18n.MsgUserInputInvalid, map[string]any{"Error": err.Error()})
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "输入不合法 " + err.Error(),
})
return
}
if user.DisplayName == "" {
@@ -778,7 +853,10 @@ func CreateUser(c *gin.Context) {
}
myRole := c.GetInt("role")
if user.Role >= myRole {
common.ApiErrorI18n(c, i18n.MsgUserCannotCreateHigherLevel)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法创建权限大于等于自己的用户",
})
return
}
// Even for admin users, we cannot fully trust them!
@@ -811,7 +889,10 @@ func ManageUser(c *gin.Context) {
err := json.NewDecoder(c.Request.Body).Decode(&req)
if err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
user := model.User{
@@ -820,26 +901,38 @@ func ManageUser(c *gin.Context) {
// Fill attributes
model.DB.Unscoped().Where(&user).First(&user)
if user.Id == 0 {
common.ApiErrorI18n(c, i18n.MsgUserNotExists)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户不存在",
})
return
}
myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserNoPermissionHigherLevel)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
})
return
}
switch req.Action {
case "disable":
user.Status = common.UserStatusDisabled
if user.Role == common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserCannotDisableRootUser)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法禁用超级管理员用户",
})
return
}
case "enable":
user.Status = common.UserStatusEnabled
case "delete":
if user.Role == common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserCannotDeleteRootUser)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法删除超级管理员用户",
})
return
}
if err := user.Delete(); err != nil {
@@ -851,21 +944,33 @@ func ManageUser(c *gin.Context) {
}
case "promote":
if myRole != common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserAdminCannotPromote)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "普通管理员用户无法提升其他用户为管理员",
})
return
}
if user.Role >= common.RoleAdminUser {
common.ApiErrorI18n(c, i18n.MsgUserAlreadyAdmin)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是管理员",
})
return
}
user.Role = common.RoleAdminUser
case "demote":
if user.Role == common.RoleRootUser {
common.ApiErrorI18n(c, i18n.MsgUserCannotDemoteRootUser)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法降级超级管理员用户",
})
return
}
if user.Role == common.RoleCommonUser {
common.ApiErrorI18n(c, i18n.MsgUserAlreadyCommon)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是普通用户",
})
return
}
user.Role = common.RoleCommonUser
@@ -891,7 +996,10 @@ func EmailBind(c *gin.Context) {
email := c.Query("email")
code := c.Query("code")
if !common.VerifyCodeWithKey(email, code, common.EmailVerificationPurpose) {
common.ApiErrorI18n(c, i18n.MsgUserVerificationCodeError)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "验证码错误或已过期",
})
return
}
session := sessions.Default(c)
@@ -967,7 +1075,10 @@ func TopUp(c *gin.Context) {
id := c.GetInt("id")
lock := getTopUpLock(id)
if !lock.TryLock() {
common.ApiErrorI18n(c, i18n.MsgUserTopUpProcessing)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "充值处理中,请稍后重试",
})
return
}
defer lock.Unlock()
@@ -979,10 +1090,6 @@ func TopUp(c *gin.Context) {
}
quota, err := model.Redeem(req.Key, id)
if err != nil {
if errors.Is(err, model.ErrRedeemFailed) {
common.ApiErrorI18n(c, i18n.MsgRedeemFailed)
return
}
common.ApiError(c, err)
return
}
@@ -1010,31 +1117,46 @@ type UpdateUserSettingRequest struct {
func UpdateUserSetting(c *gin.Context) {
var req UpdateUserSettingRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.ApiErrorI18n(c, i18n.MsgInvalidParams)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的参数",
})
return
}
// 验证预警类型
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark && req.QuotaWarningType != dto.NotifyTypeGotify {
common.ApiErrorI18n(c, i18n.MsgSettingInvalidType)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的预警类型",
})
return
}
// 验证预警阈值
if req.QuotaWarningThreshold <= 0 {
common.ApiErrorI18n(c, i18n.MsgQuotaThresholdGtZero)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "预警阈值必须大于0",
})
return
}
// 如果是webhook类型,验证webhook地址
if req.QuotaWarningType == dto.NotifyTypeWebhook {
if req.WebhookUrl == "" {
common.ApiErrorI18n(c, i18n.MsgSettingWebhookEmpty)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Webhook地址不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.WebhookUrl); err != nil {
common.ApiErrorI18n(c, i18n.MsgSettingWebhookInvalid)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Webhook地址",
})
return
}
}
@@ -1043,7 +1165,10 @@ func UpdateUserSetting(c *gin.Context) {
if req.QuotaWarningType == dto.NotifyTypeEmail && req.NotificationEmail != "" {
// 验证邮箱格式
if !strings.Contains(req.NotificationEmail, "@") {
common.ApiErrorI18n(c, i18n.MsgSettingEmailInvalid)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的邮箱地址",
})
return
}
}
@@ -1051,17 +1176,26 @@ func UpdateUserSetting(c *gin.Context) {
// 如果是Bark类型验证Bark URL
if req.QuotaWarningType == dto.NotifyTypeBark {
if req.BarkUrl == "" {
common.ApiErrorI18n(c, i18n.MsgSettingBarkUrlEmpty)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Bark推送URL不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
common.ApiErrorI18n(c, i18n.MsgSettingBarkUrlInvalid)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Bark推送URL",
})
return
}
// 检查是否是HTTP或HTTPS
if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
common.ApiErrorI18n(c, i18n.MsgSettingUrlMustHttp)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Bark推送URL必须以http://或https://开头",
})
return
}
}
@@ -1069,21 +1203,33 @@ func UpdateUserSetting(c *gin.Context) {
// 如果是Gotify类型验证Gotify URL和Token
if req.QuotaWarningType == dto.NotifyTypeGotify {
if req.GotifyUrl == "" {
common.ApiErrorI18n(c, i18n.MsgSettingGotifyUrlEmpty)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Gotify服务器地址不能为空",
})
return
}
if req.GotifyToken == "" {
common.ApiErrorI18n(c, i18n.MsgSettingGotifyTokenEmpty)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Gotify令牌不能为空",
})
return
}
// 验证URL格式
if _, err := url.ParseRequestURI(req.GotifyUrl); err != nil {
common.ApiErrorI18n(c, i18n.MsgSettingGotifyUrlInvalid)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的Gotify服务器地址",
})
return
}
// 检查是否是HTTP或HTTPS
if !strings.HasPrefix(req.GotifyUrl, "https://") && !strings.HasPrefix(req.GotifyUrl, "http://") {
common.ApiErrorI18n(c, i18n.MsgSettingUrlMustHttp)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "Gotify服务器地址必须以http://或https://开头",
})
return
}
}
@@ -1136,9 +1282,15 @@ func UpdateUserSetting(c *gin.Context) {
// 更新用户设置
user.SetSetting(settings)
if err := user.Update(false); err != nil {
common.ApiErrorI18n(c, i18n.MsgUpdateFailed)
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "更新设置失败: " + err.Error(),
})
return
}
common.ApiSuccessI18n(c, i18n.MsgSettingSaved, nil)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "设置已更新",
})
}

View File

@@ -214,14 +214,6 @@ type ClaudeRequest struct {
ServiceTier string `json:"service_tier,omitempty"`
}
// createClaudeFileSource 根据数据内容创建正确类型的 FileSource
func createClaudeFileSource(data string) *types.FileSource {
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
return types.NewURLFileSource(data)
}
return types.NewBase64FileSource(data, "")
}
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
var tokenCountMeta = types.TokenCountMeta{
TokenType: types.TokenTypeTokenizer,
@@ -251,10 +243,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage,
Source: createClaudeFileSource(data),
})
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
}
}
}
@@ -286,10 +275,7 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
data = common.Interface2String(media.Source.Data)
}
if data != "" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage,
Source: createClaudeFileSource(data),
})
fileMeta = append(fileMeta, &types.FileMeta{FileType: types.FileTypeImage, OriginData: data})
}
}
case "tool_use":

View File

@@ -64,14 +64,6 @@ type LatLng struct {
Longitude *float64 `json:"longitude,omitempty"`
}
// createGeminiFileSource 根据数据内容创建正确类型的 FileSource
func createGeminiFileSource(data string, mimeType string) *types.FileSource {
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
return types.NewURLFileSource(data)
}
return types.NewBase64FileSource(data, mimeType)
}
func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
var files []*types.FileMeta = make([]*types.FileMeta, 0)
@@ -88,23 +80,27 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
inputTexts = append(inputTexts, part.Text)
}
if part.InlineData != nil && part.InlineData.Data != "" {
mimeType := part.InlineData.MimeType
source := createGeminiFileSource(part.InlineData.Data, mimeType)
var fileType types.FileType
if strings.HasPrefix(mimeType, "image/") {
fileType = types.FileTypeImage
} else if strings.HasPrefix(mimeType, "audio/") {
fileType = types.FileTypeAudio
} else if strings.HasPrefix(mimeType, "video/") {
fileType = types.FileTypeVideo
if strings.HasPrefix(part.InlineData.MimeType, "image/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeImage,
OriginData: part.InlineData.Data,
})
} else if strings.HasPrefix(part.InlineData.MimeType, "audio/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeAudio,
OriginData: part.InlineData.Data,
})
} else if strings.HasPrefix(part.InlineData.MimeType, "video/") {
files = append(files, &types.FileMeta{
FileType: types.FileTypeVideo,
OriginData: part.InlineData.Data,
})
} else {
fileType = types.FileTypeFile
files = append(files, &types.FileMeta{
FileType: types.FileTypeFile,
OriginData: part.InlineData.Data,
})
}
files = append(files, &types.FileMeta{
FileType: fileType,
Source: source,
MimeType: mimeType,
})
}
}
}

View File

@@ -101,14 +101,6 @@ type GeneralOpenAIRequest struct {
SearchMode string `json:"search_mode,omitempty"`
}
// createFileSource 根据数据内容创建正确类型的 FileSource
func createFileSource(data string) *types.FileSource {
if strings.HasPrefix(data, "http://") || strings.HasPrefix(data, "https://") {
return types.NewURLFileSource(data)
}
return types.NewBase64FileSource(data, "")
}
func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
var tokenCountMeta types.TokenCountMeta
var texts = make([]string, 0)
@@ -152,40 +144,42 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
for _, m := range arrayContent {
if m.Type == ContentTypeImageURL {
imageUrl := m.GetImageMedia()
if imageUrl != nil && imageUrl.Url != "" {
source := createFileSource(imageUrl.Url)
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage,
Source: source,
Detail: imageUrl.Detail,
})
if imageUrl != nil {
if imageUrl.Url != "" {
meta := &types.FileMeta{
FileType: types.FileTypeImage,
}
meta.OriginData = imageUrl.Url
meta.Detail = imageUrl.Detail
fileMeta = append(fileMeta, meta)
}
}
} else if m.Type == ContentTypeInputAudio {
inputAudio := m.GetInputAudio()
if inputAudio != nil && inputAudio.Data != "" {
source := createFileSource(inputAudio.Data)
fileMeta = append(fileMeta, &types.FileMeta{
if inputAudio != nil {
meta := &types.FileMeta{
FileType: types.FileTypeAudio,
Source: source,
})
}
meta.OriginData = inputAudio.Data
fileMeta = append(fileMeta, meta)
}
} else if m.Type == ContentTypeFile {
file := m.GetFile()
if file != nil && file.FileData != "" {
source := createFileSource(file.FileData)
fileMeta = append(fileMeta, &types.FileMeta{
if file != nil {
meta := &types.FileMeta{
FileType: types.FileTypeFile,
Source: source,
})
}
meta.OriginData = file.FileData
fileMeta = append(fileMeta, meta)
}
} else if m.Type == ContentTypeVideoUrl {
videoUrl := m.GetVideoUrl()
if videoUrl != nil && videoUrl.Url != "" {
source := createFileSource(videoUrl.Url)
fileMeta = append(fileMeta, &types.FileMeta{
meta := &types.FileMeta{
FileType: types.FileTypeVideo,
Source: source,
})
}
meta.OriginData = videoUrl.Url
fileMeta = append(fileMeta, meta)
}
} else {
texts = append(texts, m.Text)
@@ -839,16 +833,16 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
if input.Type == "input_image" {
if input.ImageUrl != "" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeImage,
Source: createFileSource(input.ImageUrl),
Detail: input.Detail,
FileType: types.FileTypeImage,
OriginData: input.ImageUrl,
Detail: input.Detail,
})
}
} else if input.Type == "input_file" {
if input.FileUrl != "" {
fileMeta = append(fileMeta, &types.FileMeta{
FileType: types.FileTypeFile,
Source: createFileSource(input.FileUrl),
FileType: types.FileTypeFile,
OriginData: input.FileUrl,
})
}
} else {

View File

@@ -352,11 +352,6 @@ type ResponsesOutputContent struct {
Annotations []interface{} `json:"annotations"`
}
type ResponsesReasoningSummaryPart struct {
Type string `json:"type"`
Text string `json:"text"`
}
const (
BuildInToolWebSearchPreview = "web_search_preview"
BuildInToolFileSearch = "file_search"
@@ -379,11 +374,8 @@ type ResponsesStreamResponse struct {
Item *ResponsesOutput `json:"item,omitempty"`
// - response.function_call_arguments.delta
// - response.function_call_arguments.done
OutputIndex *int `json:"output_index,omitempty"`
ContentIndex *int `json:"content_index,omitempty"`
SummaryIndex *int `json:"summary_index,omitempty"`
ItemID string `json:"item_id,omitempty"`
Part *ResponsesReasoningSummaryPart `json:"part,omitempty"`
OutputIndex *int `json:"output_index,omitempty"`
ItemID string `json:"item_id,omitempty"`
}
// GetOpenAIError 从动态错误类型中提取OpenAIError结构

View File

@@ -14,7 +14,6 @@ type UserSetting struct {
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
BillingPreference string `json:"billing_preference,omitempty"` // BillingPreference 扣费策略(订阅/钱包)
Language string `json:"language,omitempty"` // Language 用户语言偏好 (zh, en)
}
var (

11
go.mod
View File

@@ -32,10 +32,8 @@ require (
github.com/jinzhu/copier v0.4.0
github.com/joho/godotenv v1.5.1
github.com/mewkiz/flac v1.0.13
github.com/nicksnyder/go-i18n/v2 v2.6.1
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.5.0
github.com/samber/hot v0.11.0
github.com/samber/lo v1.52.0
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/shopspring/decimal v1.4.0
@@ -50,10 +48,7 @@ require (
golang.org/x/crypto v0.45.0
golang.org/x/image v0.23.0
golang.org/x/net v0.47.0
golang.org/x/sync v0.19.0
golang.org/x/sys v0.38.0
golang.org/x/text v0.32.0
gopkg.in/yaml.v3 v3.0.1
golang.org/x/sync v0.18.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/gorm v1.25.2
@@ -120,6 +115,7 @@ require (
github.com/prometheus/procfs v0.15.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/samber/go-singleflightx v0.3.2 // indirect
github.com/samber/hot v0.11.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
@@ -131,7 +127,10 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
golang.org/x/arch v0.21.0 // indirect
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.66.10 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect

7
go.sum
View File

@@ -213,8 +213,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/nicksnyder/go-i18n/v2 v2.6.1 h1:JDEJraFsQE17Dut9HFDHzCoAWGEQJom5s0TRd17NIEQ=
github.com/nicksnyder/go-i18n/v2 v2.6.1/go.mod h1:Vee0/9RD3Quc/NmwEjzzD7VTZ+Ir7QbXocrkhOzmUKA=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
@@ -331,8 +329,6 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -353,12 +349,9 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

View File

@@ -1,227 +0,0 @@
package i18n
import (
"embed"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/nicksnyder/go-i18n/v2/i18n"
"golang.org/x/text/language"
"gopkg.in/yaml.v3"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
)
const (
LangZh = "zh"
LangEn = "en"
DefaultLang = LangEn // Fallback to English if language not supported
)
//go:embed locales/*.yaml
var localeFS embed.FS
var (
bundle *i18n.Bundle
localizers = make(map[string]*i18n.Localizer)
mu sync.RWMutex
initOnce sync.Once
)
// Init initializes the i18n bundle and loads all translation files
func Init() error {
var initErr error
initOnce.Do(func() {
bundle = i18n.NewBundle(language.Chinese)
bundle.RegisterUnmarshalFunc("yaml", yaml.Unmarshal)
// Load embedded translation files
files := []string{"locales/zh.yaml", "locales/en.yaml"}
for _, file := range files {
_, err := bundle.LoadMessageFileFS(localeFS, file)
if err != nil {
initErr = err
return
}
}
// Pre-create localizers for supported languages
localizers[LangZh] = i18n.NewLocalizer(bundle, LangZh)
localizers[LangEn] = i18n.NewLocalizer(bundle, LangEn)
// Set the TranslateMessage function in common package
common.TranslateMessage = T
})
return initErr
}
// GetLocalizer returns a localizer for the specified language
func GetLocalizer(lang string) *i18n.Localizer {
lang = normalizeLang(lang)
mu.RLock()
loc, ok := localizers[lang]
mu.RUnlock()
if ok {
return loc
}
// Create new localizer for unknown language (fallback to default)
mu.Lock()
defer mu.Unlock()
// Double-check after acquiring write lock
if loc, ok = localizers[lang]; ok {
return loc
}
loc = i18n.NewLocalizer(bundle, lang, DefaultLang)
localizers[lang] = loc
return loc
}
// T translates a message key using the language from gin context
func T(c *gin.Context, key string, args ...map[string]any) string {
lang := GetLangFromContext(c)
return Translate(lang, key, args...)
}
// Translate translates a message key for the specified language
func Translate(lang, key string, args ...map[string]any) string {
loc := GetLocalizer(lang)
config := &i18n.LocalizeConfig{
MessageID: key,
}
if len(args) > 0 && args[0] != nil {
config.TemplateData = args[0]
}
msg, err := loc.Localize(config)
if err != nil {
// Return key as fallback if translation not found
return key
}
return msg
}
// userLangLoaderFunc is a function that loads user language from database/cache
// It's set by the model package to avoid circular imports
var userLangLoaderFunc func(userId int) string
// SetUserLangLoader sets the function to load user language (called from model package)
func SetUserLangLoader(loader func(userId int) string) {
userLangLoaderFunc = loader
}
// GetLangFromContext extracts the language setting from gin context
// It checks multiple sources in priority order:
// 1. User settings (ContextKeyUserSetting) - if already loaded (e.g., by TokenAuth)
// 2. Lazy load user language from cache/DB using user ID
// 3. Language set by middleware (ContextKeyLanguage) - from Accept-Language header
// 4. Default language (English)
func GetLangFromContext(c *gin.Context) string {
if c == nil {
return DefaultLang
}
// 1. Try to get language from user settings (if already loaded by TokenAuth or other middleware)
if userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting); ok {
if userSetting.Language != "" {
normalized := normalizeLang(userSetting.Language)
if IsSupported(normalized) {
return normalized
}
}
}
// 2. Lazy load user language using user ID (for session-based auth where full settings aren't loaded)
if userLangLoaderFunc != nil {
if userId, exists := c.Get("id"); exists {
if uid, ok := userId.(int); ok && uid > 0 {
lang := userLangLoaderFunc(uid)
if lang != "" {
normalized := normalizeLang(lang)
if IsSupported(normalized) {
return normalized
}
}
}
}
}
// 3. Try to get language from context (set by I18n middleware from Accept-Language)
if lang := c.GetString(string(constant.ContextKeyLanguage)); lang != "" {
normalized := normalizeLang(lang)
if IsSupported(normalized) {
return normalized
}
}
// 4. Try Accept-Language header directly (fallback if middleware didn't run)
if acceptLang := c.GetHeader("Accept-Language"); acceptLang != "" {
lang := ParseAcceptLanguage(acceptLang)
if IsSupported(lang) {
return lang
}
}
return DefaultLang
}
// ParseAcceptLanguage parses the Accept-Language header and returns the preferred language
func ParseAcceptLanguage(header string) string {
if header == "" {
return DefaultLang
}
// Simple parsing: take the first language tag
parts := strings.Split(header, ",")
if len(parts) == 0 {
return DefaultLang
}
// Get the first language and remove quality value
firstLang := strings.TrimSpace(parts[0])
if idx := strings.Index(firstLang, ";"); idx > 0 {
firstLang = firstLang[:idx]
}
return normalizeLang(firstLang)
}
// normalizeLang normalizes language code to supported format
func normalizeLang(lang string) string {
lang = strings.ToLower(strings.TrimSpace(lang))
// Handle common variations
switch {
case strings.HasPrefix(lang, "zh"):
return LangZh
case strings.HasPrefix(lang, "en"):
return LangEn
default:
return DefaultLang
}
}
// SupportedLanguages returns a list of supported language codes
func SupportedLanguages() []string {
return []string{LangZh, LangEn}
}
// IsSupported checks if a language code is supported
func IsSupported(lang string) bool {
lang = normalizeLang(lang)
for _, supported := range SupportedLanguages() {
if lang == supported {
return true
}
}
return false
}

View File

@@ -1,300 +0,0 @@
package i18n
// Message keys for i18n translations
// Use these constants instead of hardcoded strings
// Common error messages
const (
MsgInvalidParams = "common.invalid_params"
MsgDatabaseError = "common.database_error"
MsgRetryLater = "common.retry_later"
MsgGenerateFailed = "common.generate_failed"
MsgNotFound = "common.not_found"
MsgUnauthorized = "common.unauthorized"
MsgForbidden = "common.forbidden"
MsgInvalidId = "common.invalid_id"
MsgIdEmpty = "common.id_empty"
MsgFeatureDisabled = "common.feature_disabled"
MsgOperationSuccess = "common.operation_success"
MsgOperationFailed = "common.operation_failed"
MsgUpdateSuccess = "common.update_success"
MsgUpdateFailed = "common.update_failed"
MsgCreateSuccess = "common.create_success"
MsgCreateFailed = "common.create_failed"
MsgDeleteSuccess = "common.delete_success"
MsgDeleteFailed = "common.delete_failed"
MsgAlreadyExists = "common.already_exists"
MsgNameCannotBeEmpty = "common.name_cannot_be_empty"
)
// Token related messages
const (
MsgTokenNameTooLong = "token.name_too_long"
MsgTokenQuotaNegative = "token.quota_negative"
MsgTokenQuotaExceedMax = "token.quota_exceed_max"
MsgTokenGenerateFailed = "token.generate_failed"
MsgTokenGetInfoFailed = "token.get_info_failed"
MsgTokenExpiredCannotEnable = "token.expired_cannot_enable"
MsgTokenExhaustedCannotEable = "token.exhausted_cannot_enable"
MsgTokenInvalid = "token.invalid"
MsgTokenNotProvided = "token.not_provided"
MsgTokenExpired = "token.expired"
MsgTokenExhausted = "token.exhausted"
MsgTokenStatusUnavailable = "token.status_unavailable"
MsgTokenDbError = "token.db_error"
)
// Redemption related messages
const (
MsgRedemptionNameLength = "redemption.name_length"
MsgRedemptionCountPositive = "redemption.count_positive"
MsgRedemptionCountMax = "redemption.count_max"
MsgRedemptionCreateFailed = "redemption.create_failed"
MsgRedemptionInvalid = "redemption.invalid"
MsgRedemptionUsed = "redemption.used"
MsgRedemptionExpired = "redemption.expired"
MsgRedemptionFailed = "redemption.failed"
MsgRedemptionNotProvided = "redemption.not_provided"
MsgRedemptionExpireTimeInvalid = "redemption.expire_time_invalid"
)
// User related messages
const (
MsgUserPasswordLoginDisabled = "user.password_login_disabled"
MsgUserRegisterDisabled = "user.register_disabled"
MsgUserPasswordRegisterDisabled = "user.password_register_disabled"
MsgUserUsernameOrPasswordEmpty = "user.username_or_password_empty"
MsgUserUsernameOrPasswordError = "user.username_or_password_error"
MsgUserEmailOrPasswordEmpty = "user.email_or_password_empty"
MsgUserExists = "user.exists"
MsgUserNotExists = "user.not_exists"
MsgUserDisabled = "user.disabled"
MsgUserSessionSaveFailed = "user.session_save_failed"
MsgUserRequire2FA = "user.require_2fa"
MsgUserEmailVerificationRequired = "user.email_verification_required"
MsgUserVerificationCodeError = "user.verification_code_error"
MsgUserInputInvalid = "user.input_invalid"
MsgUserNoPermissionSameLevel = "user.no_permission_same_level"
MsgUserNoPermissionHigherLevel = "user.no_permission_higher_level"
MsgUserCannotCreateHigherLevel = "user.cannot_create_higher_level"
MsgUserCannotDeleteRootUser = "user.cannot_delete_root_user"
MsgUserCannotDisableRootUser = "user.cannot_disable_root_user"
MsgUserCannotDemoteRootUser = "user.cannot_demote_root_user"
MsgUserAlreadyAdmin = "user.already_admin"
MsgUserAlreadyCommon = "user.already_common"
MsgUserAdminCannotPromote = "user.admin_cannot_promote"
MsgUserOriginalPasswordError = "user.original_password_error"
MsgUserInviteQuotaInsufficient = "user.invite_quota_insufficient"
MsgUserTransferQuotaMinimum = "user.transfer_quota_minimum"
MsgUserTransferSuccess = "user.transfer_success"
MsgUserTransferFailed = "user.transfer_failed"
MsgUserTopUpProcessing = "user.topup_processing"
MsgUserRegisterFailed = "user.register_failed"
MsgUserDefaultTokenFailed = "user.default_token_failed"
MsgUserAffCodeEmpty = "user.aff_code_empty"
MsgUserEmailEmpty = "user.email_empty"
MsgUserGitHubIdEmpty = "user.github_id_empty"
MsgUserDiscordIdEmpty = "user.discord_id_empty"
MsgUserOidcIdEmpty = "user.oidc_id_empty"
MsgUserWeChatIdEmpty = "user.wechat_id_empty"
MsgUserTelegramIdEmpty = "user.telegram_id_empty"
MsgUserTelegramNotBound = "user.telegram_not_bound"
MsgUserLinuxDOIdEmpty = "user.linux_do_id_empty"
)
// Quota related messages
const (
MsgQuotaNegative = "quota.negative"
MsgQuotaExceedMax = "quota.exceed_max"
MsgQuotaInsufficient = "quota.insufficient"
MsgQuotaWarningInvalid = "quota.warning_invalid"
MsgQuotaThresholdGtZero = "quota.threshold_gt_zero"
)
// Subscription related messages
const (
MsgSubscriptionNotEnabled = "subscription.not_enabled"
MsgSubscriptionTitleEmpty = "subscription.title_empty"
MsgSubscriptionPriceNegative = "subscription.price_negative"
MsgSubscriptionPriceMax = "subscription.price_max"
MsgSubscriptionPurchaseLimitNeg = "subscription.purchase_limit_negative"
MsgSubscriptionQuotaNegative = "subscription.quota_negative"
MsgSubscriptionGroupNotExists = "subscription.group_not_exists"
MsgSubscriptionResetCycleGtZero = "subscription.reset_cycle_gt_zero"
MsgSubscriptionPurchaseMax = "subscription.purchase_max"
MsgSubscriptionInvalidId = "subscription.invalid_id"
MsgSubscriptionInvalidUserId = "subscription.invalid_user_id"
)
// Payment related messages
const (
MsgPaymentNotConfigured = "payment.not_configured"
MsgPaymentMethodNotExists = "payment.method_not_exists"
MsgPaymentCallbackError = "payment.callback_error"
MsgPaymentCreateFailed = "payment.create_failed"
MsgPaymentStartFailed = "payment.start_failed"
MsgPaymentAmountTooLow = "payment.amount_too_low"
MsgPaymentStripeNotConfig = "payment.stripe_not_configured"
MsgPaymentWebhookNotConfig = "payment.webhook_not_configured"
MsgPaymentPriceIdNotConfig = "payment.price_id_not_configured"
MsgPaymentCreemNotConfig = "payment.creem_not_configured"
)
// Topup related messages
const (
MsgTopupNotProvided = "topup.not_provided"
MsgTopupOrderNotExists = "topup.order_not_exists"
MsgTopupOrderStatus = "topup.order_status"
MsgTopupFailed = "topup.failed"
MsgTopupInvalidQuota = "topup.invalid_quota"
)
// Channel related messages
const (
MsgChannelNotExists = "channel.not_exists"
MsgChannelIdFormatError = "channel.id_format_error"
MsgChannelNoAvailableKey = "channel.no_available_key"
MsgChannelGetListFailed = "channel.get_list_failed"
MsgChannelGetTagsFailed = "channel.get_tags_failed"
MsgChannelGetKeyFailed = "channel.get_key_failed"
MsgChannelGetOllamaFailed = "channel.get_ollama_failed"
MsgChannelQueryFailed = "channel.query_failed"
MsgChannelNoValidUpstream = "channel.no_valid_upstream"
MsgChannelUpstreamSaturated = "channel.upstream_saturated"
MsgChannelGetAvailableFailed = "channel.get_available_failed"
)
// Model related messages
const (
MsgModelNameEmpty = "model.name_empty"
MsgModelNameExists = "model.name_exists"
MsgModelIdMissing = "model.id_missing"
MsgModelGetListFailed = "model.get_list_failed"
MsgModelGetFailed = "model.get_failed"
MsgModelResetSuccess = "model.reset_success"
)
// Vendor related messages
const (
MsgVendorNameEmpty = "vendor.name_empty"
MsgVendorNameExists = "vendor.name_exists"
MsgVendorIdMissing = "vendor.id_missing"
)
// Group related messages
const (
MsgGroupNameTypeEmpty = "group.name_type_empty"
MsgGroupNameExists = "group.name_exists"
MsgGroupIdMissing = "group.id_missing"
)
// Checkin related messages
const (
MsgCheckinDisabled = "checkin.disabled"
MsgCheckinAlreadyToday = "checkin.already_today"
MsgCheckinFailed = "checkin.failed"
MsgCheckinQuotaFailed = "checkin.quota_failed"
)
// Passkey related messages
const (
MsgPasskeyCreateFailed = "passkey.create_failed"
MsgPasskeyLoginAbnormal = "passkey.login_abnormal"
MsgPasskeyUpdateFailed = "passkey.update_failed"
MsgPasskeyInvalidUserId = "passkey.invalid_user_id"
MsgPasskeyVerifyFailed = "passkey.verify_failed"
)
// 2FA related messages
const (
MsgTwoFANotEnabled = "twofa.not_enabled"
MsgTwoFAUserIdEmpty = "twofa.user_id_empty"
MsgTwoFAAlreadyExists = "twofa.already_exists"
MsgTwoFARecordIdEmpty = "twofa.record_id_empty"
MsgTwoFACodeInvalid = "twofa.code_invalid"
)
// Rate limit related messages
const (
MsgRateLimitReached = "rate_limit.reached"
MsgRateLimitTotalReached = "rate_limit.total_reached"
)
// Setting related messages
const (
MsgSettingInvalidType = "setting.invalid_type"
MsgSettingWebhookEmpty = "setting.webhook_empty"
MsgSettingWebhookInvalid = "setting.webhook_invalid"
MsgSettingEmailInvalid = "setting.email_invalid"
MsgSettingBarkUrlEmpty = "setting.bark_url_empty"
MsgSettingBarkUrlInvalid = "setting.bark_url_invalid"
MsgSettingGotifyUrlEmpty = "setting.gotify_url_empty"
MsgSettingGotifyTokenEmpty = "setting.gotify_token_empty"
MsgSettingGotifyUrlInvalid = "setting.gotify_url_invalid"
MsgSettingUrlMustHttp = "setting.url_must_http"
MsgSettingSaved = "setting.saved"
)
// Deployment related messages (io.net)
const (
MsgDeploymentNotEnabled = "deployment.not_enabled"
MsgDeploymentIdRequired = "deployment.id_required"
MsgDeploymentContainerIdReq = "deployment.container_id_required"
MsgDeploymentNameEmpty = "deployment.name_empty"
MsgDeploymentNameTaken = "deployment.name_taken"
MsgDeploymentHardwareIdReq = "deployment.hardware_id_required"
MsgDeploymentHardwareInvId = "deployment.hardware_invalid_id"
MsgDeploymentApiKeyRequired = "deployment.api_key_required"
MsgDeploymentInvalidPayload = "deployment.invalid_payload"
MsgDeploymentNotFound = "deployment.not_found"
)
// Performance related messages
const (
MsgPerfDiskCacheCleared = "performance.disk_cache_cleared"
MsgPerfStatsReset = "performance.stats_reset"
MsgPerfGcExecuted = "performance.gc_executed"
)
// Ability related messages
const (
MsgAbilityDbCorrupted = "ability.db_corrupted"
MsgAbilityRepairRunning = "ability.repair_running"
)
// OAuth related messages
const (
MsgOAuthInvalidCode = "oauth.invalid_code"
MsgOAuthGetUserErr = "oauth.get_user_error"
MsgOAuthAccountUsed = "oauth.account_used"
MsgOAuthUnknownProvider = "oauth.unknown_provider"
MsgOAuthStateInvalid = "oauth.state_invalid"
MsgOAuthNotEnabled = "oauth.not_enabled"
MsgOAuthUserDeleted = "oauth.user_deleted"
MsgOAuthUserBanned = "oauth.user_banned"
MsgOAuthBindSuccess = "oauth.bind_success"
MsgOAuthAlreadyBound = "oauth.already_bound"
MsgOAuthConnectFailed = "oauth.connect_failed"
MsgOAuthTokenFailed = "oauth.token_failed"
MsgOAuthUserInfoEmpty = "oauth.user_info_empty"
MsgOAuthTrustLevelLow = "oauth.trust_level_low"
)
// Model layer error messages (for translation in controller)
const (
MsgRedeemFailed = "redeem.failed"
MsgCreateDefaultTokenErr = "user.create_default_token_error"
MsgUuidDuplicate = "common.uuid_duplicate"
MsgInvalidInput = "common.invalid_input"
)
// Custom OAuth provider related messages
const (
MsgCustomOAuthNotFound = "custom_oauth.not_found"
MsgCustomOAuthSlugEmpty = "custom_oauth.slug_empty"
MsgCustomOAuthSlugExists = "custom_oauth.slug_exists"
MsgCustomOAuthNameEmpty = "custom_oauth.name_empty"
MsgCustomOAuthHasBindings = "custom_oauth.has_bindings"
MsgCustomOAuthBindingNotFound = "custom_oauth.binding_not_found"
MsgCustomOAuthProviderIdInvalid = "custom_oauth.provider_id_field_invalid"
)

View File

@@ -1,251 +0,0 @@
# English translations
# Common messages
common.invalid_params: "Invalid parameters"
common.database_error: "Database error, please try again later"
common.retry_later: "Please try again later"
common.generate_failed: "Generation failed"
common.not_found: "Not found"
common.unauthorized: "Unauthorized"
common.forbidden: "Forbidden"
common.invalid_id: "Invalid ID"
common.id_empty: "ID is empty!"
common.feature_disabled: "This feature is not enabled"
common.operation_success: "Operation successful"
common.operation_failed: "Operation failed"
common.update_success: "Update successful"
common.update_failed: "Update failed"
common.create_success: "Creation successful"
common.create_failed: "Creation failed"
common.delete_success: "Deletion successful"
common.delete_failed: "Deletion failed"
common.already_exists: "Already exists"
common.name_cannot_be_empty: "Name cannot be empty"
# Token messages
token.name_too_long: "Token name is too long"
token.quota_negative: "Quota value cannot be negative"
token.quota_exceed_max: "Quota value exceeds valid range, maximum is {{.Max}}"
token.generate_failed: "Failed to generate token"
token.get_info_failed: "Failed to get token info, please try again later"
token.expired_cannot_enable: "Token has expired and cannot be enabled. Please modify the expiration time or set it to never expire"
token.exhausted_cannot_enable: "Token quota is exhausted and cannot be enabled. Please modify the remaining quota or set it to unlimited"
token.invalid: "Invalid token"
token.not_provided: "Token not provided"
token.expired: "This token has expired"
token.exhausted: "This token quota is exhausted TokenStatusExhausted[sk-{{.Prefix}}***{{.Suffix}}]"
token.status_unavailable: "This token status is unavailable"
token.db_error: "Invalid token, database query error, please contact administrator"
# Redemption messages
redemption.name_length: "Redemption code name length must be between 1-20"
redemption.count_positive: "Redemption code count must be greater than 0"
redemption.count_max: "Maximum 100 redemption codes can be generated at once"
redemption.create_failed: "Failed to create redemption code, please try again later"
redemption.invalid: "Invalid redemption code"
redemption.used: "This redemption code has been used"
redemption.expired: "This redemption code has expired"
redemption.failed: "Redemption failed, please try again later"
redemption.not_provided: "Redemption code not provided"
redemption.expire_time_invalid: "Expiration time cannot be earlier than current time"
# User messages
user.password_login_disabled: "Password login has been disabled by administrator"
user.register_disabled: "New user registration has been disabled by administrator"
user.password_register_disabled: "Password registration has been disabled by administrator, please use third-party account verification"
user.username_or_password_empty: "Username or password is empty"
user.username_or_password_error: "Username or password is incorrect, or user has been banned"
user.email_or_password_empty: "Email or password is empty!"
user.exists: "Username already exists or has been deleted"
user.not_exists: "User does not exist"
user.disabled: "This user has been disabled"
user.session_save_failed: "Failed to save session, please try again"
user.require_2fa: "Please enter two-factor authentication code"
user.email_verification_required: "Email verification is enabled, please enter email address and verification code"
user.verification_code_error: "Verification code is incorrect or has expired"
user.input_invalid: "Invalid input {{.Error}}"
user.no_permission_same_level: "No permission to access users of same or higher level"
user.no_permission_higher_level: "No permission to update users of same or higher permission level"
user.cannot_create_higher_level: "Cannot create users with permission level equal to or higher than yourself"
user.cannot_delete_root_user: "Cannot delete super administrator account"
user.cannot_disable_root_user: "Cannot disable super administrator user"
user.cannot_demote_root_user: "Cannot demote super administrator user"
user.already_admin: "This user is already an administrator"
user.already_common: "This user is already a common user"
user.admin_cannot_promote: "Regular administrators cannot promote other users to administrator"
user.original_password_error: "Original password is incorrect"
user.invite_quota_insufficient: "Invitation quota is insufficient!"
user.transfer_quota_minimum: "Minimum transfer quota is {{.Min}}!"
user.transfer_success: "Transfer successful"
user.transfer_failed: "Transfer failed {{.Error}}"
user.topup_processing: "Top-up is processing, please try again later"
user.register_failed: "User registration failed or user ID retrieval failed"
user.default_token_failed: "Failed to generate default token"
user.aff_code_empty: "Affiliate code is empty!"
user.email_empty: "Email is empty!"
user.github_id_empty: "GitHub ID is empty!"
user.discord_id_empty: "Discord ID is empty!"
user.oidc_id_empty: "OIDC ID is empty!"
user.wechat_id_empty: "WeChat ID is empty!"
user.telegram_id_empty: "Telegram ID is empty!"
user.telegram_not_bound: "This Telegram account is not bound"
user.linux_do_id_empty: "Linux DO ID is empty!"
# Quota messages
quota.negative: "Quota cannot be negative!"
quota.exceed_max: "Quota value exceeds valid range"
quota.insufficient: "Insufficient quota"
quota.warning_invalid: "Invalid warning type"
quota.threshold_gt_zero: "Warning threshold must be greater than 0"
# Subscription messages
subscription.not_enabled: "Subscription plan is not enabled"
subscription.title_empty: "Subscription plan title cannot be empty"
subscription.price_negative: "Price cannot be negative"
subscription.price_max: "Price cannot exceed 9999"
subscription.purchase_limit_negative: "Purchase limit cannot be negative"
subscription.quota_negative: "Total quota cannot be negative"
subscription.group_not_exists: "Upgrade group does not exist"
subscription.reset_cycle_gt_zero: "Custom reset cycle must be greater than 0 seconds"
subscription.purchase_max: "Purchase limit for this plan has been reached"
subscription.invalid_id: "Invalid subscription ID"
subscription.invalid_user_id: "Invalid user ID"
# Payment messages
payment.not_configured: "Payment information has not been configured by administrator"
payment.method_not_exists: "Payment method does not exist"
payment.callback_error: "Callback URL configuration error"
payment.create_failed: "Failed to create order"
payment.start_failed: "Failed to start payment"
payment.amount_too_low: "Plan amount is too low"
payment.stripe_not_configured: "Stripe is not configured or key is invalid"
payment.webhook_not_configured: "Webhook is not configured"
payment.price_id_not_configured: "StripePriceId is not configured for this plan"
payment.creem_not_configured: "CreemProductId is not configured for this plan"
# Topup messages
topup.not_provided: "Payment order number not provided"
topup.order_not_exists: "Top-up order does not exist"
topup.order_status: "Top-up order status error"
topup.failed: "Top-up failed, please try again later"
topup.invalid_quota: "Invalid top-up quota"
# Channel messages
channel.not_exists: "Channel does not exist"
channel.id_format_error: "Channel ID format error"
channel.no_available_key: "No available channel keys"
channel.get_list_failed: "Failed to get channel list, please try again later"
channel.get_tags_failed: "Failed to get tags, please try again later"
channel.get_key_failed: "Failed to get channel key"
channel.get_ollama_failed: "Failed to get Ollama models"
channel.query_failed: "Failed to query channel"
channel.no_valid_upstream: "No valid upstream channel"
channel.upstream_saturated: "Current group upstream load is saturated, please try again later"
channel.get_available_failed: "Failed to get available channels for model {{.Model}} under group {{.Group}}"
# Model messages
model.name_empty: "Model name cannot be empty"
model.name_exists: "Model name already exists"
model.id_missing: "Model ID is missing"
model.get_list_failed: "Failed to get model list, please try again later"
model.get_failed: "Failed to get upstream models"
model.reset_success: "Model ratio reset successful"
# Vendor messages
vendor.name_empty: "Vendor name cannot be empty"
vendor.name_exists: "Vendor name already exists"
vendor.id_missing: "Vendor ID is missing"
# Group messages
group.name_type_empty: "Group name and type cannot be empty"
group.name_exists: "Group name already exists"
group.id_missing: "Group ID is missing"
# Checkin messages
checkin.disabled: "Check-in feature is not enabled"
checkin.already_today: "Already checked in today"
checkin.failed: "Check-in failed, please try again later"
checkin.quota_failed: "Check-in failed: quota update error"
# Passkey messages
passkey.create_failed: "Unable to create Passkey credential"
passkey.login_abnormal: "Passkey login status is abnormal"
passkey.update_failed: "Passkey credential update failed"
passkey.invalid_user_id: "Invalid user ID"
passkey.verify_failed: "Passkey verification failed, please try again or contact administrator"
# 2FA messages
twofa.not_enabled: "User has not enabled 2FA"
twofa.user_id_empty: "User ID cannot be empty"
twofa.already_exists: "User already has 2FA configured"
twofa.record_id_empty: "2FA record ID cannot be empty"
twofa.code_invalid: "Verification code or backup code is incorrect"
# Rate limit messages
rate_limit.reached: "You have reached the request limit: maximum {{.Max}} requests in {{.Minutes}} minutes"
rate_limit.total_reached: "You have reached the total request limit: maximum {{.Max}} requests in {{.Minutes}} minutes, including failed attempts"
# Setting messages
setting.invalid_type: "Invalid warning type"
setting.webhook_empty: "Webhook URL cannot be empty"
setting.webhook_invalid: "Invalid Webhook URL"
setting.email_invalid: "Invalid email address"
setting.bark_url_empty: "Bark push URL cannot be empty"
setting.bark_url_invalid: "Invalid Bark push URL"
setting.gotify_url_empty: "Gotify server URL cannot be empty"
setting.gotify_token_empty: "Gotify token cannot be empty"
setting.gotify_url_invalid: "Invalid Gotify server URL"
setting.url_must_http: "URL must start with http:// or https://"
setting.saved: "Settings updated"
# Deployment messages (io.net)
deployment.not_enabled: "io.net model deployment is not enabled or API key is missing"
deployment.id_required: "Deployment ID is required"
deployment.container_id_required: "Container ID is required"
deployment.name_empty: "Deployment name cannot be empty"
deployment.name_taken: "Deployment name is not available, please choose a different name"
deployment.hardware_id_required: "hardware_id parameter is required"
deployment.hardware_invalid_id: "Invalid hardware_id parameter"
deployment.api_key_required: "api_key is required"
deployment.invalid_payload: "Invalid request payload"
deployment.not_found: "Container details not found"
# Performance messages
performance.disk_cache_cleared: "Inactive disk cache has been cleared"
performance.stats_reset: "Statistics have been reset"
performance.gc_executed: "GC has been executed"
# Ability messages
ability.db_corrupted: "Database consistency has been compromised"
ability.repair_running: "A repair task is already running, please try again later"
# OAuth messages
oauth.invalid_code: "Invalid authorization code"
oauth.get_user_error: "Failed to get user information"
oauth.account_used: "This account has been bound to another user"
oauth.unknown_provider: "Unknown OAuth provider"
oauth.state_invalid: "State parameter is empty or mismatched"
oauth.not_enabled: "{{.Provider}} login and registration has not been enabled by administrator"
oauth.user_deleted: "User has been deleted"
oauth.user_banned: "User has been banned"
oauth.bind_success: "Binding successful"
oauth.already_bound: "This {{.Provider}} account has already been bound"
oauth.connect_failed: "Unable to connect to {{.Provider}} server, please try again later"
oauth.token_failed: "Failed to get token from {{.Provider}}, please check settings"
oauth.user_info_empty: "{{.Provider}} returned empty user info, please check settings"
oauth.trust_level_low: "Linux DO trust level does not meet the minimum required by administrator"
# Model layer error messages
redeem.failed: "Redemption failed, please try again later"
user.create_default_token_error: "Failed to create default token"
common.uuid_duplicate: "Please retry, the system generated a duplicate UUID!"
common.invalid_input: "Invalid input"
# Custom OAuth provider messages
custom_oauth.not_found: "Custom OAuth provider not found"
custom_oauth.slug_empty: "Slug cannot be empty"
custom_oauth.slug_exists: "Slug already exists"
custom_oauth.name_empty: "Provider name cannot be empty"
custom_oauth.has_bindings: "Cannot delete provider with existing user bindings"
custom_oauth.binding_not_found: "OAuth binding not found"
custom_oauth.provider_id_field_invalid: "Could not extract user ID from provider response"

View File

@@ -1,252 +0,0 @@
# Chinese (Simplified) translations
# 中文(简体)翻译文件
# Common messages
common.invalid_params: "无效的参数"
common.database_error: "数据库错误,请稍后重试"
common.retry_later: "请稍后重试"
common.generate_failed: "生成失败"
common.not_found: "未找到"
common.unauthorized: "未授权"
common.forbidden: "无权限"
common.invalid_id: "无效的ID"
common.id_empty: "ID 为空!"
common.feature_disabled: "该功能未启用"
common.operation_success: "操作成功"
common.operation_failed: "操作失败"
common.update_success: "更新成功"
common.update_failed: "更新失败"
common.create_success: "创建成功"
common.create_failed: "创建失败"
common.delete_success: "删除成功"
common.delete_failed: "删除失败"
common.already_exists: "已存在"
common.name_cannot_be_empty: "名称不能为空"
# Token messages
token.name_too_long: "令牌名称过长"
token.quota_negative: "额度值不能为负数"
token.quota_exceed_max: "额度值超出有效范围,最大值为 {{.Max}}"
token.generate_failed: "生成令牌失败"
token.get_info_failed: "获取令牌信息失败,请稍后重试"
token.expired_cannot_enable: "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期"
token.exhausted_cannot_enable: "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度"
token.invalid: "无效的令牌"
token.not_provided: "未提供令牌"
token.expired: "该令牌已过期"
token.exhausted: "该令牌额度已用尽 TokenStatusExhausted[sk-{{.Prefix}}***{{.Suffix}}]"
token.status_unavailable: "该令牌状态不可用"
token.db_error: "无效的令牌,数据库查询出错,请联系管理员"
# Redemption messages
redemption.name_length: "兑换码名称长度必须在1-20之间"
redemption.count_positive: "兑换码个数必须大于0"
redemption.count_max: "一次兑换码批量生成的个数不能大于 100"
redemption.create_failed: "创建兑换码失败,请稍后重试"
redemption.invalid: "无效的兑换码"
redemption.used: "该兑换码已被使用"
redemption.expired: "该兑换码已过期"
redemption.failed: "兑换失败,请稍后重试"
redemption.not_provided: "未提供兑换码"
redemption.expire_time_invalid: "过期时间不能早于当前时间"
# User messages
user.password_login_disabled: "管理员关闭了密码登录"
user.register_disabled: "管理员关闭了新用户注册"
user.password_register_disabled: "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册"
user.username_or_password_empty: "用户名或密码为空"
user.username_or_password_error: "用户名或密码错误,或用户已被封禁"
user.email_or_password_empty: "邮箱地址或密码为空!"
user.exists: "用户名已存在,或已注销"
user.not_exists: "用户不存在"
user.disabled: "该用户已被禁用"
user.session_save_failed: "无法保存会话信息,请重试"
user.require_2fa: "请输入两步验证码"
user.email_verification_required: "管理员开启了邮箱验证,请输入邮箱地址和验证码"
user.verification_code_error: "验证码错误或已过期"
user.input_invalid: "输入不合法 {{.Error}}"
user.no_permission_same_level: "无权获取同级或更高等级用户的信息"
user.no_permission_higher_level: "无权更新同权限等级或更高权限等级的用户信息"
user.cannot_create_higher_level: "无法创建权限大于等于自己的用户"
user.cannot_delete_root_user: "不能删除超级管理员账户"
user.cannot_disable_root_user: "无法禁用超级管理员用户"
user.cannot_demote_root_user: "无法降级超级管理员用户"
user.already_admin: "该用户已经是管理员"
user.already_common: "该用户已经是普通用户"
user.admin_cannot_promote: "普通管理员用户无法提升其他用户为管理员"
user.original_password_error: "原密码错误"
user.invite_quota_insufficient: "邀请额度不足!"
user.transfer_quota_minimum: "转移额度最小为{{.Min}}"
user.transfer_success: "划转成功"
user.transfer_failed: "划转失败 {{.Error}}"
user.topup_processing: "充值处理中,请稍后重试"
user.register_failed: "用户注册失败或用户ID获取失败"
user.default_token_failed: "生成默认令牌失败"
user.aff_code_empty: "affCode 为空!"
user.email_empty: "email 为空!"
user.github_id_empty: "GitHub id 为空!"
user.discord_id_empty: "discord id 为空!"
user.oidc_id_empty: "oidc id 为空!"
user.wechat_id_empty: "WeChat id 为空!"
user.telegram_id_empty: "Telegram id 为空!"
user.telegram_not_bound: "该 Telegram 账户未绑定"
user.linux_do_id_empty: "Linux DO id 为空!"
# Quota messages
quota.negative: "额度不能为负数!"
quota.exceed_max: "额度值超出有效范围"
quota.insufficient: "额度不足"
quota.warning_invalid: "无效的预警类型"
quota.threshold_gt_zero: "预警阈值必须大于0"
# Subscription messages
subscription.not_enabled: "套餐未启用"
subscription.title_empty: "套餐标题不能为空"
subscription.price_negative: "价格不能为负数"
subscription.price_max: "价格不能超过9999"
subscription.purchase_limit_negative: "购买上限不能为负数"
subscription.quota_negative: "总额度不能为负数"
subscription.group_not_exists: "升级分组不存在"
subscription.reset_cycle_gt_zero: "自定义重置周期需大于0秒"
subscription.purchase_max: "已达到该套餐购买上限"
subscription.invalid_id: "无效的订阅ID"
subscription.invalid_user_id: "无效的用户ID"
# Payment messages
payment.not_configured: "当前管理员未配置支付信息"
payment.method_not_exists: "支付方式不存在"
payment.callback_error: "回调地址配置错误"
payment.create_failed: "创建订单失败"
payment.start_failed: "拉起支付失败"
payment.amount_too_low: "套餐金额过低"
payment.stripe_not_configured: "Stripe 未配置或密钥无效"
payment.webhook_not_configured: "Webhook 未配置"
payment.price_id_not_configured: "该套餐未配置 StripePriceId"
payment.creem_not_configured: "该套餐未配置 CreemProductId"
# Topup messages
topup.not_provided: "未提供支付单号"
topup.order_not_exists: "充值订单不存在"
topup.order_status: "充值订单状态错误"
topup.failed: "充值失败,请稍后重试"
topup.invalid_quota: "无效的充值额度"
# Channel messages
channel.not_exists: "渠道不存在"
channel.id_format_error: "渠道ID格式错误"
channel.no_available_key: "没有可用的渠道密钥"
channel.get_list_failed: "获取渠道列表失败,请稍后重试"
channel.get_tags_failed: "获取标签失败,请稍后重试"
channel.get_key_failed: "获取渠道密钥失败"
channel.get_ollama_failed: "获取Ollama模型失败"
channel.query_failed: "查询渠道失败"
channel.no_valid_upstream: "无有效上游渠道"
channel.upstream_saturated: "当前分组上游负载已饱和,请稍后再试"
channel.get_available_failed: "获取分组 {{.Group}} 下模型 {{.Model}} 的可用渠道失败"
# Model messages
model.name_empty: "模型名称不能为空"
model.name_exists: "模型名称已存在"
model.id_missing: "缺少模型 ID"
model.get_list_failed: "获取模型列表失败,请稍后重试"
model.get_failed: "获取上游模型失败"
model.reset_success: "重置模型倍率成功"
# Vendor messages
vendor.name_empty: "供应商名称不能为空"
vendor.name_exists: "供应商名称已存在"
vendor.id_missing: "缺少供应商 ID"
# Group messages
group.name_type_empty: "组名称和类型不能为空"
group.name_exists: "组名称已存在"
group.id_missing: "缺少组 ID"
# Checkin messages
checkin.disabled: "签到功能未启用"
checkin.already_today: "今日已签到"
checkin.failed: "签到失败,请稍后重试"
checkin.quota_failed: "签到失败:更新额度出错"
# Passkey messages
passkey.create_failed: "无法创建 Passkey 凭证"
passkey.login_abnormal: "Passkey 登录状态异常"
passkey.update_failed: "Passkey 凭证更新失败"
passkey.invalid_user_id: "无效的用户 ID"
passkey.verify_failed: "Passkey 验证失败,请重试或联系管理员"
# 2FA messages
twofa.not_enabled: "用户未启用2FA"
twofa.user_id_empty: "用户ID不能为空"
twofa.already_exists: "用户已存在2FA设置"
twofa.record_id_empty: "2FA记录ID不能为空"
twofa.code_invalid: "验证码或备用码不正确"
# Rate limit messages
rate_limit.reached: "您已达到请求数限制:{{.Minutes}}分钟内最多请求{{.Max}}次"
rate_limit.total_reached: "您已达到总请求数限制:{{.Minutes}}分钟内最多请求{{.Max}}次,包括失败次数"
# Setting messages
setting.invalid_type: "无效的预警类型"
setting.webhook_empty: "Webhook地址不能为空"
setting.webhook_invalid: "无效的Webhook地址"
setting.email_invalid: "无效的邮箱地址"
setting.bark_url_empty: "Bark推送URL不能为空"
setting.bark_url_invalid: "无效的Bark推送URL"
setting.gotify_url_empty: "Gotify服务器地址不能为空"
setting.gotify_token_empty: "Gotify令牌不能为空"
setting.gotify_url_invalid: "无效的Gotify服务器地址"
setting.url_must_http: "URL必须以http://或https://开头"
setting.saved: "设置已更新"
# Deployment messages (io.net)
deployment.not_enabled: "io.net 模型部署功能未启用或 API 密钥缺失"
deployment.id_required: "deployment ID 为必填项"
deployment.container_id_required: "container ID 为必填项"
deployment.name_empty: "deployment 名称不能为空"
deployment.name_taken: "deployment 名称已被使用,请选择其他名称"
deployment.hardware_id_required: "hardware_id 参数为必填项"
deployment.hardware_invalid_id: "无效的 hardware_id 参数"
deployment.api_key_required: "api_key 为必填项"
deployment.invalid_payload: "无效的请求内容"
deployment.not_found: "未找到容器详情"
# Performance messages
performance.disk_cache_cleared: "不活跃的磁盘缓存已清理"
performance.stats_reset: "统计信息已重置"
performance.gc_executed: "GC 已执行"
# Ability messages
ability.db_corrupted: "数据库一致性被破坏"
ability.repair_running: "已经有一个修复任务在运行中,请稍后再试"
# OAuth messages
oauth.invalid_code: "无效的授权码"
oauth.get_user_error: "获取用户信息失败"
oauth.account_used: "该账户已被其他用户绑定"
oauth.unknown_provider: "未知的 OAuth 提供商"
oauth.state_invalid: "state 参数为空或不匹配"
oauth.not_enabled: "管理员未开启通过 {{.Provider}} 登录以及注册"
oauth.user_deleted: "用户已注销"
oauth.user_banned: "用户已被封禁"
oauth.bind_success: "绑定成功"
oauth.already_bound: "该 {{.Provider}} 账户已被绑定"
oauth.connect_failed: "无法连接至 {{.Provider}} 服务器,请稍后重试"
oauth.token_failed: "{{.Provider}} 获取 Token 失败,请检查设置"
oauth.user_info_empty: "{{.Provider}} 获取用户信息为空,请检查设置"
oauth.trust_level_low: "Linux DO 信任等级未达到管理员设置的最低信任等级"
# Model layer error messages
redeem.failed: "兑换失败,请稍后重试"
user.create_default_token_error: "创建默认令牌失败"
common.uuid_duplicate: "请重试,系统生成的 UUID 竟然重复了!"
common.invalid_input: "输入不合法"
# Custom OAuth provider messages
custom_oauth.not_found: "自定义 OAuth 提供商不存在"
custom_oauth.slug_empty: "标识符不能为空"
custom_oauth.slug_exists: "标识符已存在"
custom_oauth.name_empty: "提供商名称不能为空"
custom_oauth.has_bindings: "无法删除已有用户绑定的提供商"
custom_oauth.binding_not_found: "OAuth 绑定不存在"
custom_oauth.provider_id_field_invalid: "无法从提供商响应中提取用户 ID"

25
main.go
View File

@@ -14,11 +14,9 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/controller"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/middleware"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/oauth"
"github.com/QuantumNous/new-api/router"
"github.com/QuantumNous/new-api/service"
_ "github.com/QuantumNous/new-api/setting/performance_setting"
@@ -153,7 +151,6 @@ func main() {
//server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.RequestId())
server.Use(middleware.PoweredBy())
server.Use(middleware.I18n())
middleware.SetUpLogger(server)
// Initialize session store
store := cookie.NewStore([]byte(common.SessionSecret))
@@ -277,27 +274,5 @@ func InitResources() error {
if err != nil {
return err
}
// 启动系统监控
common.StartSystemMonitor()
// Initialize i18n
err = i18n.Init()
if err != nil {
common.SysError("failed to initialize i18n: " + err.Error())
// Don't return error, i18n is not critical
} else {
common.SysLog("i18n initialized with languages: " + strings.Join(i18n.SupportedLanguages(), ", "))
}
// Register user language loader for lazy loading
i18n.SetUserLangLoader(model.GetUserLanguage)
// Load custom OAuth providers from database
err = oauth.LoadCustomProviders()
if err != nil {
common.SysError("failed to load custom OAuth providers: " + err.Error())
// Don't return error, custom OAuth is not critical
}
return nil
}

View File

@@ -132,6 +132,17 @@ func authHelper(c *gin.Context, minRole int) {
c.Set("user_group", session.Get("group"))
c.Set("use_access_token", useAccessToken)
//userCache, err := model.GetUserCache(id.(int))
//if err != nil {
// c.JSON(http.StatusOK, gin.H{
// "success": false,
// "message": err.Error(),
// })
// c.Abort()
// return
//}
//userCache.WriteContext(c)
c.Next()
}
@@ -168,63 +179,6 @@ func WssAuth(c *gin.Context) {
}
// TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
// 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
// 即使令牌已过期、已耗尽或已禁用,也允许访问。
// 仍然检查用户是否被封禁。
func TokenAuthReadOnly() func(c *gin.Context) {
return func(c *gin.Context) {
key := c.Request.Header.Get("Authorization")
if key == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "未提供 Authorization 请求头",
})
c.Abort()
return
}
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
key = strings.TrimSpace(key[7:])
}
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
key = parts[0]
token, err := model.GetTokenByKey(key, false)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无效的令牌",
})
c.Abort()
return
}
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
c.Abort()
return
}
if userCache.Status != common.UserStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "用户已被封禁",
})
c.Abort()
return
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
c.Next()
}
}
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
// 先检测是否为ws

View File

@@ -2,7 +2,6 @@ package middleware
import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
@@ -15,8 +14,5 @@ func BodyStorageCleanup() gin.HandlerFunc {
// 请求结束后清理存储
common.CleanupBodyStorage(c)
// 清理文件缓存URL 下载的文件等)
service.CleanupFileSources(c)
}
}

View File

@@ -1,50 +0,0 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/i18n"
)
// I18n middleware detects and sets the language preference for the request
func I18n() gin.HandlerFunc {
return func(c *gin.Context) {
lang := detectLanguage(c)
c.Set(string(constant.ContextKeyLanguage), lang)
c.Next()
}
}
// detectLanguage determines the language preference for the request
// Priority: 1. User setting (if logged in) -> 2. Accept-Language header -> 3. Default language
func detectLanguage(c *gin.Context) string {
// 1. Try to get language from user setting (set by auth middleware)
if userSetting, ok := common.GetContextKeyType[dto.UserSetting](c, constant.ContextKeyUserSetting); ok {
if userSetting.Language != "" && i18n.IsSupported(userSetting.Language) {
return userSetting.Language
}
}
// 2. Parse Accept-Language header
acceptLang := c.GetHeader("Accept-Language")
if acceptLang != "" {
lang := i18n.ParseAcceptLanguage(acceptLang)
if i18n.IsSupported(lang) {
return lang
}
}
// 3. Return default language
return i18n.DefaultLang
}
// GetLanguage returns the current language from gin context
func GetLanguage(c *gin.Context) string {
if lang := c.GetString(string(constant.ContextKeyLanguage)); lang != "" {
return lang
}
return i18n.DefaultLang
}

View File

@@ -1,65 +0,0 @@
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// SystemPerformanceCheck 检查系统性能中间件
func SystemPerformanceCheck() gin.HandlerFunc {
return func(c *gin.Context) {
// 仅检查 Relay 接口 (/v1, /v1beta 等)
// 这里简单判断路径前缀,可以根据实际路由调整
path := c.Request.URL.Path
if strings.HasPrefix(path, "/v1/messages") {
if err := checkSystemPerformance(); err != nil {
c.JSON(err.StatusCode, gin.H{
"error": err.ToClaudeError(),
})
c.Abort()
return
}
} else {
if err := checkSystemPerformance(); err != nil {
c.JSON(err.StatusCode, gin.H{
"error": err.ToOpenAIError(),
})
c.Abort()
return
}
}
c.Next()
}
}
// checkSystemPerformance 检查系统性能是否超过阈值
func checkSystemPerformance() *types.NewAPIError {
config := common.GetPerformanceMonitorConfig()
if !config.Enabled {
return nil
}
status := common.GetSystemStatus()
// 检查 CPU
if config.CPUThreshold > 0 && int(status.CPUUsage) > config.CPUThreshold {
return types.NewErrorWithStatusCode(errors.New("system cpu overloaded"), "system_cpu_overloaded", http.StatusServiceUnavailable)
}
// 检查内存
if config.MemoryThreshold > 0 && int(status.MemoryUsage) > config.MemoryThreshold {
return types.NewErrorWithStatusCode(errors.New("system memory overloaded"), "system_memory_overloaded", http.StatusServiceUnavailable)
}
// 检查磁盘
if config.DiskThreshold > 0 && int(status.DiskUsage) > config.DiskThreshold {
return types.NewErrorWithStatusCode(errors.New("system disk overloaded"), "system_disk_overloaded", http.StatusServiceUnavailable)
}
return nil
}

View File

@@ -115,88 +115,3 @@ func DownloadRateLimit() func(c *gin.Context) {
func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
}
// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
// instead of client IP, making it resistant to proxy rotation attacks.
// Must be used AFTER authentication middleware (UserAuth).
func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
if common.RedisEnabled {
return func(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
userRedisRateLimiter(c, maxRequestNum, duration, key)
}
}
// It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
key := fmt.Sprintf("%s:user:%d", mark, userId)
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
}
}
// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
// (to support user-ID-based keys).
func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
ctx := context.Background()
rdb := common.RDB
listLength, err := rdb.LLen(ctx, key).Result()
if err != nil {
fmt.Println(err.Error())
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
fmt.Println(err)
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
fmt.Println(err)
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
}
}
// SearchRateLimit returns a per-user rate limiter for search endpoints.
// 10 requests per 60 seconds per user (by user ID, not IP).
func SearchRateLimit() func(c *gin.Context) {
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
}

View File

@@ -1,158 +0,0 @@
package model
import (
"errors"
"strings"
"time"
)
// CustomOAuthProvider stores configuration for custom OAuth providers
type CustomOAuthProvider struct {
Id int `json:"id" gorm:"primaryKey"`
Name string `json:"name" gorm:"type:varchar(64);not null"` // Display name, e.g., "GitHub Enterprise"
Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"` // URL identifier, e.g., "github-enterprise"
Enabled bool `json:"enabled" gorm:"default:false"` // Whether this provider is enabled
ClientId string `json:"client_id" gorm:"type:varchar(256)"` // OAuth client ID
ClientSecret string `json:"-" gorm:"type:varchar(512)"` // OAuth client secret (not returned to frontend)
AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"` // Authorization URL
TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"` // Token exchange URL
UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"` // User info URL
Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"` // OAuth scopes
// Field mapping configuration (supports JSONPath via gjson)
UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"` // User ID field path, e.g., "sub", "id", "data.user.id"
UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"` // Username field path
DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"` // Display name field path
EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"` // Email field path
// Advanced options
WellKnown string `json:"well_known" gorm:"type:varchar(512)"` // OIDC discovery endpoint (optional)
AuthStyle int `json:"auth_style" gorm:"default:0"` // 0=auto, 1=params, 2=header (Basic Auth)
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (CustomOAuthProvider) TableName() string {
return "custom_oauth_providers"
}
// GetAllCustomOAuthProviders returns all custom OAuth providers
func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
var providers []*CustomOAuthProvider
err := DB.Order("id asc").Find(&providers).Error
return providers, err
}
// GetEnabledCustomOAuthProviders returns all enabled custom OAuth providers
func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
var providers []*CustomOAuthProvider
err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error
return providers, err
}
// GetCustomOAuthProviderById returns a custom OAuth provider by ID
func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
var provider CustomOAuthProvider
err := DB.First(&provider, id).Error
if err != nil {
return nil, err
}
return &provider, nil
}
// GetCustomOAuthProviderBySlug returns a custom OAuth provider by slug
func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
var provider CustomOAuthProvider
err := DB.Where("slug = ?", slug).First(&provider).Error
if err != nil {
return nil, err
}
return &provider, nil
}
// CreateCustomOAuthProvider creates a new custom OAuth provider
func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if err := validateCustomOAuthProvider(provider); err != nil {
return err
}
return DB.Create(provider).Error
}
// UpdateCustomOAuthProvider updates an existing custom OAuth provider
func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if err := validateCustomOAuthProvider(provider); err != nil {
return err
}
return DB.Save(provider).Error
}
// DeleteCustomOAuthProvider deletes a custom OAuth provider by ID
func DeleteCustomOAuthProvider(id int) error {
// First, delete all user bindings for this provider
if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
return err
}
return DB.Delete(&CustomOAuthProvider{}, id).Error
}
// IsSlugTaken checks if a slug is already taken by another provider
func IsSlugTaken(slug string, excludeId int) bool {
var count int64
query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
if excludeId > 0 {
query = query.Where("id != ?", excludeId)
}
query.Count(&count)
return count > 0
}
// validateCustomOAuthProvider validates a custom OAuth provider configuration
func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
if provider.Name == "" {
return errors.New("provider name is required")
}
if provider.Slug == "" {
return errors.New("provider slug is required")
}
// Slug must be lowercase and contain only alphanumeric characters and hyphens
slug := strings.ToLower(provider.Slug)
for _, c := range slug {
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
}
}
provider.Slug = slug
if provider.ClientId == "" {
return errors.New("client ID is required")
}
if provider.AuthorizationEndpoint == "" {
return errors.New("authorization endpoint is required")
}
if provider.TokenEndpoint == "" {
return errors.New("token endpoint is required")
}
if provider.UserInfoEndpoint == "" {
return errors.New("user info endpoint is required")
}
// Set defaults for field mappings if empty
if provider.UserIdField == "" {
provider.UserIdField = "sub"
}
if provider.UsernameField == "" {
provider.UsernameField = "preferred_username"
}
if provider.DisplayNameField == "" {
provider.DisplayNameField = "name"
}
if provider.EmailField == "" {
provider.EmailField = "email"
}
if provider.Scopes == "" {
provider.Scopes = "openid profile email"
}
return nil
}

View File

@@ -2,8 +2,9 @@ package model
import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
@@ -17,8 +18,8 @@ import (
)
type Log struct {
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1;index:idx_user_id_id,priority:2"`
UserId int `json:"user_id" gorm:"index;index:idx_user_id_id,priority:1"`
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"`
@@ -35,7 +36,6 @@ type Log struct {
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Ip string `json:"ip" gorm:"index;default:''"`
RequestId string `json:"request_id,omitempty" gorm:"type:varchar(64);index:idx_logs_request_id;default:''"`
Other string `json:"other"`
}
@@ -50,7 +50,7 @@ const (
LogTypeRefund = 6
)
func formatUserLogs(logs []*Log, startIdx int) {
func formatUserLogs(logs []*Log) {
for i := range logs {
logs[i].ChannelName = ""
var otherMap map[string]interface{}
@@ -58,16 +58,25 @@ func formatUserLogs(logs []*Log, startIdx int) {
if otherMap != nil {
// Remove admin-only debug fields.
delete(otherMap, "admin_info")
delete(otherMap, "request_conversion")
delete(otherMap, "reject_reason")
}
logs[i].Other = common.MapToJsonStr(otherMap)
logs[i].Id = startIdx + i + 1
logs[i].Id = logs[i].Id % 1024
}
}
func GetLogByTokenId(tokenId int) (logs []*Log, err error) {
err = LOG_DB.Model(&Log{}).Where("token_id = ?", tokenId).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
formatUserLogs(logs, 0)
func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
} else {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
}
formatUserLogs(logs)
return logs, err
}
@@ -93,7 +102,6 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
isStream bool, group string, other map[string]interface{}) {
logger.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
username := c.GetString("username")
requestId := c.GetString(common.RequestIdKey)
otherStr := common.MapToJsonStr(other)
// 判断是否需要记录 IP
needRecordIp := false
@@ -124,8 +132,7 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
}
return ""
}(),
RequestId: requestId,
Other: otherStr,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -154,7 +161,6 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
username := c.GetString("username")
requestId := c.GetString(common.RequestIdKey)
otherStr := common.MapToJsonStr(params.Other)
// 判断是否需要记录 IP
needRecordIp := false
@@ -185,8 +191,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
return ""
}(),
RequestId: requestId,
Other: otherStr,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
if err != nil {
@@ -199,7 +204,7 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams)
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string, requestId string) (logs []*Log, total int64, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB
@@ -216,9 +221,6 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if tokenName != "" {
tx = tx.Where("logs.token_name = ?", tokenName)
}
if requestId != "" {
tx = tx.Where("logs.request_id = ?", requestId)
}
if startTimestamp != 0 {
tx = tx.Where("logs.created_at >= ?", startTimestamp)
}
@@ -267,9 +269,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return logs, total, err
}
const logSearchCountLimit = 10000
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB.Where("logs.user_id = ?", userId)
@@ -278,18 +278,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
}
if modelName != "" {
modelNamePattern, err := sanitizeLikePattern(modelName)
if err != nil {
return nil, 0, err
}
tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern)
tx = tx.Where("logs.model_name like ?", modelName)
}
if tokenName != "" {
tx = tx.Where("logs.token_name = ?", tokenName)
}
if requestId != "" {
tx = tx.Where("logs.request_id = ?", requestId)
}
if startTimestamp != 0 {
tx = tx.Where("logs.created_at >= ?", startTimestamp)
}
@@ -299,28 +292,37 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if group != "" {
tx = tx.Where("logs."+logGroupCol+" = ?", group)
}
err = tx.Model(&Log{}).Limit(logSearchCountLimit).Count(&total).Error
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
common.SysError("failed to count user logs: " + err.Error())
return nil, 0, errors.New("查询日志失败")
return nil, 0, err
}
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil {
common.SysError("failed to search user logs: " + err.Error())
return nil, 0, errors.New("查询日志失败")
return nil, 0, err
}
formatUserLogs(logs, startIdx)
formatUserLogs(logs)
return logs, total, err
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
formatUserLogs(logs)
return logs, err
}
type Stat struct {
Quota int `json:"quota"`
Rpm int `json:"rpm"`
Tpm int `json:"tpm"`
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询
@@ -341,12 +343,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("created_at <= ?", endTimestamp)
}
if modelName != "" {
modelNamePattern, err := sanitizeLikePattern(modelName)
if err != nil {
return stat, err
}
tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
tx = tx.Where("model_name like ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
@@ -364,16 +362,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// 执行查询
if err := tx.Scan(&stat).Error; err != nil {
common.SysError("failed to query log stat: " + err.Error())
return stat, errors.New("查询统计数据失败")
}
if err := rpmTpmQuery.Scan(&stat).Error; err != nil {
common.SysError("failed to query rpm/tpm stat: " + err.Error())
return stat, errors.New("查询统计数据失败")
}
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
return stat, nil
return stat
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {

View File

@@ -248,9 +248,6 @@ func InitLogDB() (err error) {
}
func migrateDB() error {
// Migrate price_amount column from float/double to decimal for existing tables
migrateSubscriptionPlanPriceAmount()
err := DB.AutoMigrate(
&Channel{},
&Token{},
@@ -271,24 +268,14 @@ func migrateDB() error {
&TwoFA{},
&TwoFABackupCode{},
&Checkin{},
&SubscriptionPlan{},
&SubscriptionOrder{},
&UserSubscription{},
&SubscriptionPreConsumeRecord{},
&CustomOAuthProvider{},
&UserOAuthBinding{},
)
if err != nil {
return err
}
if common.UsingSQLite {
if err := ensureSubscriptionPlanTableSQLite(); err != nil {
return err
}
} else {
if err := DB.AutoMigrate(&SubscriptionPlan{}); err != nil {
return err
}
}
return nil
}
@@ -319,11 +306,10 @@ func migrateDBFast() error {
{&TwoFA{}, "TwoFA"},
{&TwoFABackupCode{}, "TwoFABackupCode"},
{&Checkin{}, "Checkin"},
{&SubscriptionPlan{}, "SubscriptionPlan"},
{&SubscriptionOrder{}, "SubscriptionOrder"},
{&UserSubscription{}, "UserSubscription"},
{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
{&CustomOAuthProvider{}, "CustomOAuthProvider"},
{&UserOAuthBinding{}, "UserOAuthBinding"},
}
// 动态计算migration数量确保errChan缓冲区足够大
errChan := make(chan error, len(migrations))
@@ -348,15 +334,6 @@ func migrateDBFast() error {
return err
}
}
if common.UsingSQLite {
if err := ensureSubscriptionPlanTableSQLite(); err != nil {
return err
}
} else {
if err := DB.AutoMigrate(&SubscriptionPlan{}); err != nil {
return err
}
}
common.SysLog("database migrated")
return nil
}
@@ -369,139 +346,6 @@ func migrateLOGDB() error {
return nil
}
type sqliteColumnDef struct {
Name string
DDL string
}
func ensureSubscriptionPlanTableSQLite() error {
if !common.UsingSQLite {
return nil
}
tableName := "subscription_plans"
if !DB.Migrator().HasTable(tableName) {
createSQL := `CREATE TABLE ` + "`" + tableName + "`" + ` (
` + "`id`" + ` integer,
` + "`title`" + ` varchar(128) NOT NULL,
` + "`subtitle`" + ` varchar(255) DEFAULT '',
` + "`price_amount`" + ` decimal(10,6) NOT NULL,
` + "`currency`" + ` varchar(8) NOT NULL DEFAULT 'USD',
` + "`duration_unit`" + ` varchar(16) NOT NULL DEFAULT 'month',
` + "`duration_value`" + ` integer NOT NULL DEFAULT 1,
` + "`custom_seconds`" + ` bigint NOT NULL DEFAULT 0,
` + "`enabled`" + ` numeric DEFAULT 1,
` + "`sort_order`" + ` integer DEFAULT 0,
` + "`stripe_price_id`" + ` varchar(128) DEFAULT '',
` + "`creem_product_id`" + ` varchar(128) DEFAULT '',
` + "`max_purchase_per_user`" + ` integer DEFAULT 0,
` + "`upgrade_group`" + ` varchar(64) DEFAULT '',
` + "`total_amount`" + ` bigint NOT NULL DEFAULT 0,
` + "`quota_reset_period`" + ` varchar(16) DEFAULT 'never',
` + "`quota_reset_custom_seconds`" + ` bigint DEFAULT 0,
` + "`created_at`" + ` bigint,
` + "`updated_at`" + ` bigint,
PRIMARY KEY (` + "`id`" + `)
)`
return DB.Exec(createSQL).Error
}
var cols []struct {
Name string `gorm:"column:name"`
}
if err := DB.Raw("PRAGMA table_info(`" + tableName + "`)").Scan(&cols).Error; err != nil {
return err
}
existing := make(map[string]struct{}, len(cols))
for _, c := range cols {
existing[c.Name] = struct{}{}
}
required := []sqliteColumnDef{
{Name: "title", DDL: "`title` varchar(128) NOT NULL"},
{Name: "subtitle", DDL: "`subtitle` varchar(255) DEFAULT ''"},
{Name: "price_amount", DDL: "`price_amount` decimal(10,6) NOT NULL"},
{Name: "currency", DDL: "`currency` varchar(8) NOT NULL DEFAULT 'USD'"},
{Name: "duration_unit", DDL: "`duration_unit` varchar(16) NOT NULL DEFAULT 'month'"},
{Name: "duration_value", DDL: "`duration_value` integer NOT NULL DEFAULT 1"},
{Name: "custom_seconds", DDL: "`custom_seconds` bigint NOT NULL DEFAULT 0"},
{Name: "enabled", DDL: "`enabled` numeric DEFAULT 1"},
{Name: "sort_order", DDL: "`sort_order` integer DEFAULT 0"},
{Name: "stripe_price_id", DDL: "`stripe_price_id` varchar(128) DEFAULT ''"},
{Name: "creem_product_id", DDL: "`creem_product_id` varchar(128) DEFAULT ''"},
{Name: "max_purchase_per_user", DDL: "`max_purchase_per_user` integer DEFAULT 0"},
{Name: "upgrade_group", DDL: "`upgrade_group` varchar(64) DEFAULT ''"},
{Name: "total_amount", DDL: "`total_amount` bigint NOT NULL DEFAULT 0"},
{Name: "quota_reset_period", DDL: "`quota_reset_period` varchar(16) DEFAULT 'never'"},
{Name: "quota_reset_custom_seconds", DDL: "`quota_reset_custom_seconds` bigint DEFAULT 0"},
{Name: "created_at", DDL: "`created_at` bigint"},
{Name: "updated_at", DDL: "`updated_at` bigint"},
}
for _, col := range required {
if _, ok := existing[col.Name]; ok {
continue
}
if err := DB.Exec("ALTER TABLE `" + tableName + "` ADD COLUMN " + col.DDL).Error; err != nil {
return err
}
}
return nil
}
// migrateSubscriptionPlanPriceAmount migrates price_amount column from float/double to decimal(10,6)
// This is safe to run multiple times - it checks the column type first
func migrateSubscriptionPlanPriceAmount() {
// SQLite doesn't support ALTER COLUMN, and its type affinity handles this automatically
// Skip early to avoid GORM parsing the existing table DDL which may cause issues
if common.UsingSQLite {
return
}
tableName := "subscription_plans"
columnName := "price_amount"
// Check if table exists first
if !DB.Migrator().HasTable(tableName) {
return
}
// Check if column exists
if !DB.Migrator().HasColumn(&SubscriptionPlan{}, columnName) {
return
}
var alterSQL string
if common.UsingPostgreSQL {
// PostgreSQL: Check if already decimal/numeric
var dataType string
DB.Raw(`SELECT data_type FROM information_schema.columns
WHERE table_name = ? AND column_name = ?`, tableName, columnName).Scan(&dataType)
if dataType == "numeric" {
return // Already decimal/numeric
}
alterSQL = fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s TYPE decimal(10,6) USING %s::decimal(10,6)`,
tableName, columnName, columnName)
} else if common.UsingMySQL {
// MySQL: Check if already decimal
var columnType string
DB.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
tableName, columnName).Scan(&columnType)
if strings.HasPrefix(strings.ToLower(columnType), "decimal") {
return // Already decimal
}
alterSQL = fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s decimal(10,6) NOT NULL DEFAULT 0",
tableName, columnName)
} else {
return
}
if alterSQL != "" {
if err := DB.Exec(alterSQL).Error; err != nil {
common.SysLog(fmt.Sprintf("Warning: failed to migrate %s.%s to decimal: %v", tableName, columnName, err))
} else {
common.SysLog(fmt.Sprintf("Successfully migrated %s.%s to decimal(10,6)", tableName, columnName))
}
}
}
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {

View File

@@ -11,9 +11,6 @@ import (
"gorm.io/gorm"
)
// ErrRedeemFailed is returned when redemption fails due to database error
var ErrRedeemFailed = errors.New("redeem.failed")
type Redemption struct {
Id int `json:"id"`
UserId int `json:"user_id"`
@@ -151,8 +148,7 @@ func Redeem(key string, userId int) (quota int, err error) {
return err
})
if err != nil {
common.SysError("redemption failed: " + err.Error())
return 0, ErrRedeemFailed
return 0, errors.New("兑换失败," + err.Error())
}
RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s兑换码ID %d", logger.LogQuota(redemption.Quota), redemption.Id))
return redemption.Quota, nil

View File

@@ -149,7 +149,7 @@ type SubscriptionPlan struct {
Subtitle string `json:"subtitle" gorm:"type:varchar(255);default:''"`
// Display money amount (follow existing code style: float64 for money)
PriceAmount float64 `json:"price_amount" gorm:"type:decimal(10,6);not null;default:0"`
PriceAmount float64 `json:"price_amount" gorm:"type:double;not null;default:0"`
Currency string `json:"currency" gorm:"type:varchar(8);not null;default:'USD'"`
DurationUnit string `json:"duration_unit" gorm:"type:varchar(16);not null;default:'month'"`
@@ -666,22 +666,6 @@ func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
return buildSubscriptionSummaries(subs), nil
}
// HasActiveUserSubscription returns whether the user has any active subscription.
// This is a lightweight existence check to avoid heavy pre-consume transactions.
func HasActiveUserSubscription(userId int) (bool, error) {
if userId <= 0 {
return false, errors.New("invalid userId")
}
now := common.GetTimestamp()
var count int64
if err := DB.Model(&UserSubscription{}).
Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
// GetAllUserSubscriptions returns all subscriptions (active and expired) for a user.
func GetAllUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
if userId <= 0 {

View File

@@ -57,7 +57,6 @@ type Task struct {
FinishTime int64 `json:"finish_time" gorm:"index"`
Progress string `json:"progress" gorm:"type:varchar(20);index"`
Properties Properties `json:"properties" gorm:"type:json"`
Username string `json:"username,omitempty" gorm:"-"`
// 禁止返回给用户内部可能包含key等隐私信息
PrivateData TaskPrivateData `json:"-" gorm:"column:private_data;type:json"`
Data json.RawMessage `json:"data" gorm:"type:json"`
@@ -234,12 +233,6 @@ func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*
return nil
}
for _, task := range tasks {
if cache, err := GetUserCache(task.UserId); err == nil {
task.Username = cache.Username
}
}
return tasks
}

View File

@@ -6,7 +6,6 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -64,104 +63,12 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
return tokens, err
}
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
// 规则:
// 1. 转义 ! 和 _使用 ! 作为 ESCAPE 字符,兼容 MySQL/PostgreSQL/SQLite
// 2. 连续的 % 合并为单个 %
// 3. 最多允许 2 个 %
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
// 5. 不含 % 时按精确匹配
func sanitizeLikePattern(input string) (string, error) {
// 1. 先转义 ESCAPE 字符 ! 自身,再转义 _
// 使用 ! 而非 \ 作为 ESCAPE 字符,避免 MySQL 中反斜杠的字符串转义问题
input = strings.ReplaceAll(input, "!", "!!")
input = strings.ReplaceAll(input, `_`, `!_`)
// 2. 连续的 % 直接拒绝
if strings.Contains(input, "%%") {
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
}
// 3. 统计 % 数量,不得超过 2
count := strings.Count(input, "%")
if count > 2 {
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
}
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
if count > 0 {
stripped := strings.ReplaceAll(input, "%", "")
if len(stripped) < 2 {
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
}
return input, nil
}
// 5. 无 % 时,精确全匹配
return input, nil
}
const searchHardLimit = 100
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
// model 层强制截断
if limit <= 0 || limit > searchHardLimit {
limit = searchHardLimit
}
if offset < 0 {
offset = 0
}
func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
if token != "" {
token = strings.Trim(token, "sk-")
}
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
maxTokens := operation_setting.GetMaxUserTokens()
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
if hasFuzzy {
count, err := CountUserTokens(userId)
if err != nil {
common.SysLog("failed to count user tokens: " + err.Error())
return nil, 0, errors.New("获取令牌数量失败")
}
if int(count) > maxTokens {
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
}
}
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
if keyword != "" {
keywordPattern, err := sanitizeLikePattern(keyword)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '!'", keywordPattern)
}
if token != "" {
tokenPattern, err := sanitizeLikePattern(token)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '!'", tokenPattern)
}
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT
err = baseQuery.Limit(maxTokens).Count(&total).Error
if err != nil {
common.SysError("failed to count search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
// 再分页查数据
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
if err != nil {
common.SysError("failed to search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
return tokens, total, nil
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
}
func ValidateUserToken(key string) (token *Token, err error) {

View File

@@ -95,8 +95,7 @@ func Recharge(referenceId string, customerId string) (err error) {
})
if err != nil {
common.SysError("topup failed: " + err.Error())
return errors.New("充值失败,请稍后重试")
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v支付金额%d", logger.FormatQuota(int(quota)), topUp.Amount))
@@ -368,8 +367,7 @@ func RechargeCreem(referenceId string, customerEmail string, customerName string
})
if err != nil {
common.SysError("creem topup failed: " + err.Error())
return errors.New("充值失败,请稍后重试")
return errors.New("充值失败," + err.Error())
}
RecordLog(topUp.UserId, LogTypeTopup, fmt.Sprintf("使用Creem充值成功充值额度: %v支付金额%.2f", quota, topUp.Money))

View File

@@ -540,14 +540,6 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
// UpdateGitHubId updates the user's GitHub ID (used for migration from login to numeric ID)
func (user *User) UpdateGitHubId(newGitHubId string) error {
if user.Id == 0 {
return errors.New("user id is empty")
}
return DB.Model(user).Update("github_id", newGitHubId).Error
}
func (user *User) FillUserByDiscordId() error {
if user.DiscordId == "" {
return errors.New("discord id 为空!")

View File

@@ -221,13 +221,3 @@ func updateUserSettingCache(userId int, setting string) error {
}
return common.RedisHSetField(getUserCacheKey(userId), "Setting", setting)
}
// GetUserLanguage returns the user's language preference from cache
// Uses the existing GetUserCache mechanism for efficiency
func GetUserLanguage(userId int) string {
userCache, err := GetUserCache(userId)
if err != nil {
return ""
}
return userCache.GetSetting().Language
}

View File

@@ -1,125 +0,0 @@
package model
import (
"errors"
"time"
)
// UserOAuthBinding stores the binding relationship between users and custom OAuth providers
type UserOAuthBinding struct {
Id int `json:"id" gorm:"primaryKey"`
UserId int `json:"user_id" gorm:"index;not null"` // User ID
ProviderId int `json:"provider_id" gorm:"index;not null"` // Custom OAuth provider ID
ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null"` // User ID from OAuth provider
CreatedAt time.Time `json:"created_at"`
// Composite unique index to prevent duplicate bindings
// One OAuth account can only be bound to one user
}
func (UserOAuthBinding) TableName() string {
return "user_oauth_bindings"
}
// GetUserOAuthBindingsByUserId returns all OAuth bindings for a user
func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) {
var bindings []*UserOAuthBinding
err := DB.Where("user_id = ?", userId).Find(&bindings).Error
return bindings, err
}
// GetUserOAuthBinding returns a specific binding for a user and provider
func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) {
var binding UserOAuthBinding
err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
if err != nil {
return nil, err
}
return &binding, nil
}
// GetUserByOAuthBinding finds a user by provider ID and provider user ID
func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) {
var binding UserOAuthBinding
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(&binding).Error
if err != nil {
return nil, err
}
var user User
err = DB.First(&user, binding.UserId).Error
if err != nil {
return nil, err
}
return &user, nil
}
// IsProviderUserIdTaken checks if a provider user ID is already bound to any user
func IsProviderUserIdTaken(providerId int, providerUserId string) bool {
var count int64
DB.Model(&UserOAuthBinding{}).Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).Count(&count)
return count > 0
}
// CreateUserOAuthBinding creates a new OAuth binding
func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
if binding.UserId == 0 {
return errors.New("user ID is required")
}
if binding.ProviderId == 0 {
return errors.New("provider ID is required")
}
if binding.ProviderUserId == "" {
return errors.New("provider user ID is required")
}
// Check if this provider user ID is already taken
if IsProviderUserIdTaken(binding.ProviderId, binding.ProviderUserId) {
return errors.New("this OAuth account is already bound to another user")
}
binding.CreatedAt = time.Now()
return DB.Create(binding).Error
}
// UpdateUserOAuthBinding updates an existing OAuth binding (e.g., rebind to different OAuth account)
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
// Check if the new provider user ID is already taken by another user
var existingBinding UserOAuthBinding
err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existingBinding).Error
if err == nil && existingBinding.UserId != userId {
return errors.New("this OAuth account is already bound to another user")
}
// Check if user already has a binding for this provider
var binding UserOAuthBinding
err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
if err != nil {
// No existing binding, create new one
return CreateUserOAuthBinding(&UserOAuthBinding{
UserId: userId,
ProviderId: providerId,
ProviderUserId: newProviderUserId,
})
}
// Update existing binding
return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error
}
// DeleteUserOAuthBinding deletes an OAuth binding
func DeleteUserOAuthBinding(userId, providerId int) error {
return DB.Where("user_id = ? AND provider_id = ?", userId, providerId).Delete(&UserOAuthBinding{}).Error
}
// DeleteUserOAuthBindingsByUserId deletes all OAuth bindings for a user
func DeleteUserOAuthBindingsByUserId(userId int) error {
return DB.Where("user_id = ?", userId).Delete(&UserOAuthBinding{}).Error
}
// GetBindingCountByProviderId returns the number of bindings for a provider
func GetBindingCountByProviderId(providerId int) (int64, error) {
var count int64
err := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&count).Error
return count, err
}

View File

@@ -1,172 +0,0 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
func init() {
Register("discord", &DiscordProvider{})
}
// DiscordProvider implements OAuth for Discord
type DiscordProvider struct{}
type discordOAuthResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type discordUser struct {
UID string `json:"id"`
ID string `json:"username"`
Name string `json:"global_name"`
}
func (p *DiscordProvider) GetName() string {
return "Discord"
}
func (p *DiscordProvider) IsEnabled() bool {
return system_setting.GetDiscordSettings().Enabled
}
func (p *DiscordProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: code=%s...", code[:min(len(code), 10)])
settings := system_setting.GetDiscordSettings()
redirectUri := fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress)
values := url.Values{}
values.Set("client_id", settings.ClientId)
values.Set("client_secret", settings.ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", redirectUri)
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken: redirect_uri=%s", redirectUri)
req, err := http.NewRequestWithContext(ctx, "POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken response status: %d", res.StatusCode)
var discordResponse discordOAuthResponse
err = json.NewDecoder(res.Body).Decode(&discordResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if discordResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-Discord] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Discord"})
}
logger.LogDebug(ctx, "[OAuth-Discord] ExchangeToken success: scope=%s", discordResponse.Scope)
return &OAuthToken{
AccessToken: discordResponse.AccessToken,
TokenType: discordResponse.TokenType,
RefreshToken: discordResponse.RefreshToken,
ExpiresIn: discordResponse.ExpiresIn,
Scope: discordResponse.Scope,
IDToken: discordResponse.IDToken,
}, nil
}
func (p *DiscordProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo: fetching user info")
req, err := http.NewRequestWithContext(ctx, "GET", "https://discord.com/api/v10/users/@me", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Discord"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo response status: %d", res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo failed: status=%d", res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
var discordUser discordUser
err = json.NewDecoder(res.Body).Decode(&discordUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Discord] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if discordUser.UID == "" || discordUser.ID == "" {
logger.LogError(ctx, "[OAuth-Discord] GetUserInfo failed: empty user fields")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Discord"})
}
logger.LogDebug(ctx, "[OAuth-Discord] GetUserInfo success: uid=%s, username=%s, name=%s", discordUser.UID, discordUser.ID, discordUser.Name)
return &OAuthUser{
ProviderUserID: discordUser.UID,
Username: discordUser.ID,
DisplayName: discordUser.Name,
}, nil
}
func (p *DiscordProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsDiscordIdAlreadyTaken(providerUserID)
}
func (p *DiscordProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.DiscordId = providerUserID
return user.FillUserByDiscordId()
}
func (p *DiscordProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.DiscordId = providerUserID
}
func (p *DiscordProvider) GetProviderPrefix() string {
return "discord_"
}

View File

@@ -1,268 +0,0 @@
package oauth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// AuthStyle defines how to send client credentials
const (
AuthStyleAutoDetect = 0 // Auto-detect based on server response
AuthStyleInParams = 1 // Send client_id and client_secret as POST parameters
AuthStyleInHeader = 2 // Send as Basic Auth header
)
// GenericOAuthProvider implements OAuth for custom/generic OAuth providers
type GenericOAuthProvider struct {
config *model.CustomOAuthProvider
}
// NewGenericOAuthProvider creates a new generic OAuth provider from config
func NewGenericOAuthProvider(config *model.CustomOAuthProvider) *GenericOAuthProvider {
return &GenericOAuthProvider{config: config}
}
func (p *GenericOAuthProvider) GetName() string {
return p.config.Name
}
func (p *GenericOAuthProvider) IsEnabled() bool {
return p.config.Enabled
}
func (p *GenericOAuthProvider) GetConfig() *model.CustomOAuthProvider {
return p.config
}
func (p *GenericOAuthProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: code=%s...", p.config.Slug, code[:min(len(code), 10)])
redirectUri := fmt.Sprintf("%s/oauth/%s", system_setting.ServerAddress, p.config.Slug)
values := url.Values{}
values.Set("grant_type", "authorization_code")
values.Set("code", code)
values.Set("redirect_uri", redirectUri)
// Determine auth style
authStyle := p.config.AuthStyle
if authStyle == AuthStyleAutoDetect {
// Default to params style for most OAuth servers
authStyle = AuthStyleInParams
}
var req *http.Request
var err error
if authStyle == AuthStyleInParams {
values.Set("client_id", p.config.ClientId)
values.Set("client_secret", p.config.ClientSecret)
}
req, err = http.NewRequestWithContext(ctx, "POST", p.config.TokenEndpoint, strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
if authStyle == AuthStyleInHeader {
// Basic Auth
credentials := base64.StdEncoding.EncodeToString([]byte(p.config.ClientId + ":" + p.config.ClientSecret))
req.Header.Set("Authorization", "Basic "+credentials)
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken: token_endpoint=%s, redirect_uri=%s, auth_style=%d",
p.config.Slug, p.config.TokenEndpoint, redirectUri, authStyle)
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken error: %s", p.config.Slug, err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response status: %d", p.config.Slug, res.StatusCode)
body, err := io.ReadAll(res.Body)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken read body error: %s", p.config.Slug, err.Error()))
return nil, err
}
bodyStr := string(body)
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
// Try to parse as JSON first
var tokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
IDToken string `json:"id_token"`
Error string `json:"error"`
ErrorDesc string `json:"error_description"`
}
if err := json.Unmarshal(body, &tokenResponse); err != nil {
// Try to parse as URL-encoded (some OAuth servers like GitHub return this format)
parsedValues, parseErr := url.ParseQuery(bodyStr)
if parseErr != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken parse error: %s", p.config.Slug, err.Error()))
return nil, err
}
tokenResponse.AccessToken = parsedValues.Get("access_token")
tokenResponse.TokenType = parsedValues.Get("token_type")
tokenResponse.Scope = parsedValues.Get("scope")
}
if tokenResponse.Error != "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken OAuth error: %s - %s",
p.config.Slug, tokenResponse.Error, tokenResponse.ErrorDesc))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name}, tokenResponse.ErrorDesc)
}
if tokenResponse.AccessToken == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] ExchangeToken failed: empty access token", p.config.Slug))
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": p.config.Name})
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] ExchangeToken success: scope=%s", p.config.Slug, tokenResponse.Scope)
return &OAuthToken{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
ExpiresIn: tokenResponse.ExpiresIn,
Scope: tokenResponse.Scope,
IDToken: tokenResponse.IDToken,
}, nil
}
func (p *GenericOAuthProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo: fetching user info from %s", p.config.Slug, p.config.UserInfoEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", p.config.UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
// Set authorization header
tokenType := token.TokenType
if tokenType == "" {
tokenType = "Bearer"
}
req.Header.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token.AccessToken))
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo error: %s", p.config.Slug, err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": p.config.Name}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response status: %d", p.config.Slug, res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: status=%d", p.config.Slug, res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
body, err := io.ReadAll(res.Body)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo read body error: %s", p.config.Slug, err.Error()))
return nil, err
}
bodyStr := string(body)
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo response body: %s", p.config.Slug, bodyStr[:min(len(bodyStr), 500)])
// Extract fields using gjson (supports JSONPath-like syntax)
userId := gjson.Get(bodyStr, p.config.UserIdField).String()
username := gjson.Get(bodyStr, p.config.UsernameField).String()
displayName := gjson.Get(bodyStr, p.config.DisplayNameField).String()
email := gjson.Get(bodyStr, p.config.EmailField).String()
// If user ID field returns a number, convert it
if userId == "" {
// Try to get as number
userIdNum := gjson.Get(bodyStr, p.config.UserIdField)
if userIdNum.Exists() {
userId = userIdNum.Raw
// Remove quotes if present
userId = strings.Trim(userId, "\"")
}
}
if userId == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-Generic-%s] GetUserInfo failed: empty user ID (field: %s)", p.config.Slug, p.config.UserIdField))
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": p.config.Name})
}
logger.LogDebug(ctx, "[OAuth-Generic-%s] GetUserInfo success: id=%s, username=%s, name=%s, email=%s",
p.config.Slug, userId, username, displayName, email)
return &OAuthUser{
ProviderUserID: userId,
Username: username,
DisplayName: displayName,
Email: email,
}, nil
}
func (p *GenericOAuthProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsProviderUserIdTaken(p.config.Id, providerUserID)
}
func (p *GenericOAuthProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
foundUser, err := model.GetUserByOAuthBinding(p.config.Id, providerUserID)
if err != nil {
return err
}
*user = *foundUser
return nil
}
func (p *GenericOAuthProvider) SetProviderUserID(user *model.User, providerUserID string) {
// For generic providers, we store the binding in user_oauth_bindings table
// This is handled separately in the OAuth controller
}
func (p *GenericOAuthProvider) GetProviderPrefix() string {
return p.config.Slug + "_"
}
// GetProviderId returns the provider ID for binding purposes
func (p *GenericOAuthProvider) GetProviderId() int {
return p.config.Id
}
// IsGenericProvider returns true for generic providers
func (p *GenericOAuthProvider) IsGenericProvider() bool {
return true
}

View File

@@ -1,166 +0,0 @@
package oauth
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
func init() {
Register("github", &GitHubProvider{})
}
// GitHubProvider implements OAuth for GitHub
type GitHubProvider struct{}
type gitHubOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type gitHubUser struct {
Id int64 `json:"id"` // GitHub numeric ID (permanent, never changes)
Login string `json:"login"` // GitHub username (can be changed by user)
Name string `json:"name"`
Email string `json:"email"`
}
func (p *GitHubProvider) GetName() string {
return "GitHub"
}
func (p *GitHubProvider) IsEnabled() bool {
return common.GitHubOAuthEnabled
}
func (p *GitHubProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken: code=%s...", code[:min(len(code), 10)])
values := map[string]string{
"client_id": common.GitHubClientId,
"client_secret": common.GitHubClientSecret,
"code": code,
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", "https://github.com/login/oauth/access_token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken response status: %d", res.StatusCode)
var oAuthResponse gitHubOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if oAuthResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-GitHub] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "GitHub"})
}
logger.LogDebug(ctx, "[OAuth-GitHub] ExchangeToken success: scope=%s", oAuthResponse.Scope)
return &OAuthToken{
AccessToken: oAuthResponse.AccessToken,
TokenType: oAuthResponse.TokenType,
Scope: oAuthResponse.Scope,
}, nil
}
func (p *GitHubProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo: fetching user info")
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
client := http.Client{
Timeout: 20 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "GitHub"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo response status: %d", res.StatusCode)
var githubUser gitHubUser
err = json.NewDecoder(res.Body).Decode(&githubUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-GitHub] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if githubUser.Id == 0 || githubUser.Login == "" {
logger.LogError(ctx, "[OAuth-GitHub] GetUserInfo failed: empty id or login field")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "GitHub"})
}
logger.LogDebug(ctx, "[OAuth-GitHub] GetUserInfo success: id=%d, login=%s, name=%s, email=%s",
githubUser.Id, githubUser.Login, githubUser.Name, githubUser.Email)
return &OAuthUser{
ProviderUserID: strconv.FormatInt(githubUser.Id, 10), // Use numeric ID as primary identifier
Username: githubUser.Login,
DisplayName: githubUser.Name,
Email: githubUser.Email,
Extra: map[string]any{
"legacy_id": githubUser.Login, // Store login for migration from old accounts
},
}, nil
}
func (p *GitHubProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsGitHubIdAlreadyTaken(providerUserID)
}
func (p *GitHubProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.GitHubId = providerUserID
return user.FillUserByGitHubId()
}
func (p *GitHubProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.GitHubId = providerUserID
}
func (p *GitHubProvider) GetProviderPrefix() string {
return "github_"
}

View File

@@ -1,195 +0,0 @@
package oauth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
func init() {
Register("linuxdo", &LinuxDOProvider{})
}
// LinuxDOProvider implements OAuth for Linux DO
type LinuxDOProvider struct{}
type linuxdoUser struct {
Id int `json:"id"`
Username string `json:"username"`
Name string `json:"name"`
Active bool `json:"active"`
TrustLevel int `json:"trust_level"`
Silenced bool `json:"silenced"`
}
func (p *LinuxDOProvider) GetName() string {
return "Linux DO"
}
func (p *LinuxDOProvider) IsEnabled() bool {
return common.LinuxDOOAuthEnabled
}
func (p *LinuxDOProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: code=%s...", code[:min(len(code), 10)])
// Get access token using Basic auth
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
// Get redirect URI from request
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
}
redirectURI := fmt.Sprintf("%s://%s/api/oauth/linuxdo", scheme, c.Request.Host)
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken: token_endpoint=%s, redirect_uri=%s", tokenEndpoint, redirectURI)
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", redirectURI)
req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", basicAuth)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken response status: %d", res.StatusCode)
var tokenRes struct {
AccessToken string `json:"access_token"`
Message string `json:"message"`
}
if err := json.NewDecoder(res.Body).Decode(&tokenRes); err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if tokenRes.AccessToken == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] ExchangeToken failed: %s", tokenRes.Message))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "Linux DO"}, tokenRes.Message)
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] ExchangeToken success")
return &OAuthToken{
AccessToken: tokenRes.AccessToken,
}, nil
}
func (p *LinuxDOProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: user_endpoint=%s", userEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
req.Header.Set("Accept", "application/json")
client := http.Client{Timeout: 5 * time.Second}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "Linux DO"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo response status: %d", res.StatusCode)
var linuxdoUser linuxdoUser
if err := json.NewDecoder(res.Body).Decode(&linuxdoUser); err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if linuxdoUser.Id == 0 {
logger.LogError(ctx, "[OAuth-LinuxDO] GetUserInfo failed: invalid user id")
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "Linux DO"})
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo: id=%d, username=%s, name=%s, trust_level=%d, active=%v, silenced=%v",
linuxdoUser.Id, linuxdoUser.Username, linuxdoUser.Name, linuxdoUser.TrustLevel, linuxdoUser.Active, linuxdoUser.Silenced)
// Check trust level
if linuxdoUser.TrustLevel < common.LinuxDOMinimumTrustLevel {
logger.LogWarn(ctx, fmt.Sprintf("[OAuth-LinuxDO] GetUserInfo: trust level too low (required=%d, current=%d)",
common.LinuxDOMinimumTrustLevel, linuxdoUser.TrustLevel))
return nil, &TrustLevelError{
Required: common.LinuxDOMinimumTrustLevel,
Current: linuxdoUser.TrustLevel,
}
}
logger.LogDebug(ctx, "[OAuth-LinuxDO] GetUserInfo success: id=%d, username=%s", linuxdoUser.Id, linuxdoUser.Username)
return &OAuthUser{
ProviderUserID: strconv.Itoa(linuxdoUser.Id),
Username: linuxdoUser.Username,
DisplayName: linuxdoUser.Name,
Extra: map[string]any{
"trust_level": linuxdoUser.TrustLevel,
"active": linuxdoUser.Active,
"silenced": linuxdoUser.Silenced,
},
}, nil
}
func (p *LinuxDOProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsLinuxDOIdAlreadyTaken(providerUserID)
}
func (p *LinuxDOProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.LinuxDOId = providerUserID
return user.FillUserByLinuxDOId()
}
func (p *LinuxDOProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.LinuxDOId = providerUserID
}
func (p *LinuxDOProvider) GetProviderPrefix() string {
return "linuxdo_"
}
// TrustLevelError indicates the user's trust level is too low
type TrustLevelError struct {
Required int
Current int
}
func (e *TrustLevelError) Error() string {
return "trust level too low"
}

View File

@@ -1,177 +0,0 @@
package oauth
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-gonic/gin"
)
func init() {
Register("oidc", &OIDCProvider{})
}
// OIDCProvider implements OAuth for OIDC
type OIDCProvider struct{}
type oidcOAuthResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type oidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func (p *OIDCProvider) GetName() string {
return "OIDC"
}
func (p *OIDCProvider) IsEnabled() bool {
return system_setting.GetOIDCSettings().Enabled
}
func (p *OIDCProvider) ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error) {
if code == "" {
return nil, NewOAuthError(i18n.MsgOAuthInvalidCode, nil)
}
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: code=%s...", code[:min(len(code), 10)])
settings := system_setting.GetOIDCSettings()
redirectUri := fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)
values := url.Values{}
values.Set("client_id", settings.ClientId)
values.Set("client_secret", settings.ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", redirectUri)
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken: token_endpoint=%s, redirect_uri=%s", settings.TokenEndpoint, redirectUri)
req, err := http.NewRequestWithContext(ctx, "POST", settings.TokenEndpoint, strings.NewReader(values.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken response status: %d", res.StatusCode)
var oidcResponse oidcOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] ExchangeToken decode error: %s", err.Error()))
return nil, err
}
if oidcResponse.AccessToken == "" {
logger.LogError(ctx, "[OAuth-OIDC] ExchangeToken failed: empty access token")
return nil, NewOAuthError(i18n.MsgOAuthTokenFailed, map[string]any{"Provider": "OIDC"})
}
logger.LogDebug(ctx, "[OAuth-OIDC] ExchangeToken success: scope=%s", oidcResponse.Scope)
return &OAuthToken{
AccessToken: oidcResponse.AccessToken,
TokenType: oidcResponse.TokenType,
RefreshToken: oidcResponse.RefreshToken,
ExpiresIn: oidcResponse.ExpiresIn,
Scope: oidcResponse.Scope,
IDToken: oidcResponse.IDToken,
}, nil
}
func (p *OIDCProvider) GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error) {
settings := system_setting.GetOIDCSettings()
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo: userinfo_endpoint=%s", settings.UserInfoEndpoint)
req, err := http.NewRequestWithContext(ctx, "GET", settings.UserInfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo error: %s", err.Error()))
return nil, NewOAuthErrorWithRaw(i18n.MsgOAuthConnectFailed, map[string]any{"Provider": "OIDC"}, err.Error())
}
defer res.Body.Close()
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo response status: %d", res.StatusCode)
if res.StatusCode != http.StatusOK {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: status=%d", res.StatusCode))
return nil, NewOAuthError(i18n.MsgOAuthGetUserErr, nil)
}
var oidcUser oidcUser
err = json.NewDecoder(res.Body).Decode(&oidcUser)
if err != nil {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo decode error: %s", err.Error()))
return nil, err
}
if oidcUser.OpenID == "" || oidcUser.Email == "" {
logger.LogError(ctx, fmt.Sprintf("[OAuth-OIDC] GetUserInfo failed: empty fields (sub=%s, email=%s)", oidcUser.OpenID, oidcUser.Email))
return nil, NewOAuthError(i18n.MsgOAuthUserInfoEmpty, map[string]any{"Provider": "OIDC"})
}
logger.LogDebug(ctx, "[OAuth-OIDC] GetUserInfo success: sub=%s, username=%s, name=%s, email=%s", oidcUser.OpenID, oidcUser.PreferredUsername, oidcUser.Name, oidcUser.Email)
return &OAuthUser{
ProviderUserID: oidcUser.OpenID,
Username: oidcUser.PreferredUsername,
DisplayName: oidcUser.Name,
Email: oidcUser.Email,
}, nil
}
func (p *OIDCProvider) IsUserIDTaken(providerUserID string) bool {
return model.IsOidcIdAlreadyTaken(providerUserID)
}
func (p *OIDCProvider) FillUserByProviderID(user *model.User, providerUserID string) error {
user.OidcId = providerUserID
return user.FillUserByOidcId()
}
func (p *OIDCProvider) SetProviderUserID(user *model.User, providerUserID string) {
user.OidcId = providerUserID
}
func (p *OIDCProvider) GetProviderPrefix() string {
return "oidc_"
}

View File

@@ -1,36 +0,0 @@
package oauth
import (
"context"
"github.com/QuantumNous/new-api/model"
"github.com/gin-gonic/gin"
)
// Provider defines the interface for OAuth providers
type Provider interface {
// GetName returns the display name of the provider (e.g., "GitHub", "Discord")
GetName() string
// IsEnabled returns whether this OAuth provider is enabled
IsEnabled() bool
// ExchangeToken exchanges the authorization code for an access token
// The gin.Context is passed for providers that need request info (e.g., for redirect_uri)
ExchangeToken(ctx context.Context, code string, c *gin.Context) (*OAuthToken, error)
// GetUserInfo retrieves user information using the access token
GetUserInfo(ctx context.Context, token *OAuthToken) (*OAuthUser, error)
// IsUserIDTaken checks if the provider user ID is already associated with an account
IsUserIDTaken(providerUserID string) bool
// FillUserByProviderID fills the user model by provider user ID
FillUserByProviderID(user *model.User, providerUserID string) error
// SetProviderUserID sets the provider user ID on the user model
SetProviderUserID(user *model.User, providerUserID string)
// GetProviderPrefix returns the prefix for auto-generated usernames (e.g., "github_")
GetProviderPrefix() string
}

View File

@@ -1,134 +0,0 @@
package oauth
import (
"fmt"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
)
var (
providers = make(map[string]Provider)
mu sync.RWMutex
// customProviderSlugs tracks which providers are custom (can be unregistered)
customProviderSlugs = make(map[string]bool)
)
// Register registers an OAuth provider with the given name
func Register(name string, provider Provider) {
mu.Lock()
defer mu.Unlock()
providers[name] = provider
}
// RegisterCustom registers a custom OAuth provider (can be unregistered later)
func RegisterCustom(name string, provider Provider) {
mu.Lock()
defer mu.Unlock()
providers[name] = provider
customProviderSlugs[name] = true
}
// Unregister removes a provider from the registry
func Unregister(name string) {
mu.Lock()
defer mu.Unlock()
delete(providers, name)
delete(customProviderSlugs, name)
}
// GetProvider returns the OAuth provider for the given name
func GetProvider(name string) Provider {
mu.RLock()
defer mu.RUnlock()
return providers[name]
}
// GetAllProviders returns all registered OAuth providers
func GetAllProviders() map[string]Provider {
mu.RLock()
defer mu.RUnlock()
result := make(map[string]Provider, len(providers))
for k, v := range providers {
result[k] = v
}
return result
}
// GetEnabledCustomProviders returns all enabled custom OAuth providers
func GetEnabledCustomProviders() []*GenericOAuthProvider {
mu.RLock()
defer mu.RUnlock()
var result []*GenericOAuthProvider
for name, provider := range providers {
if customProviderSlugs[name] {
if gp, ok := provider.(*GenericOAuthProvider); ok && gp.IsEnabled() {
result = append(result, gp)
}
}
}
return result
}
// IsProviderRegistered checks if a provider is registered
func IsProviderRegistered(name string) bool {
mu.RLock()
defer mu.RUnlock()
_, ok := providers[name]
return ok
}
// IsCustomProvider checks if a provider is a custom provider
func IsCustomProvider(name string) bool {
mu.RLock()
defer mu.RUnlock()
return customProviderSlugs[name]
}
// LoadCustomProviders loads all custom OAuth providers from the database
func LoadCustomProviders() error {
// First, unregister all existing custom providers
mu.Lock()
for name := range customProviderSlugs {
delete(providers, name)
}
customProviderSlugs = make(map[string]bool)
mu.Unlock()
// Load all custom providers from database
customProviders, err := model.GetAllCustomOAuthProviders()
if err != nil {
common.SysError("Failed to load custom OAuth providers: " + err.Error())
return err
}
// Register each custom provider
for _, config := range customProviders {
provider := NewGenericOAuthProvider(config)
RegisterCustom(config.Slug, provider)
common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")")
}
common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders)))
return nil
}
// ReloadCustomProviders reloads all custom OAuth providers from the database
func ReloadCustomProviders() error {
return LoadCustomProviders()
}
// RegisterOrUpdateCustomProvider registers or updates a single custom provider
func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) {
provider := NewGenericOAuthProvider(config)
mu.Lock()
defer mu.Unlock()
providers[config.Slug] = provider
customProviderSlugs[config.Slug] = true
}
// UnregisterCustomProvider unregisters a custom provider by slug
func UnregisterCustomProvider(slug string) {
Unregister(slug)
}

View File

@@ -1,59 +0,0 @@
package oauth
// OAuthToken represents the token received from OAuth provider
type OAuthToken struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Scope string `json:"scope,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
// OAuthUser represents the user info from OAuth provider
type OAuthUser struct {
// ProviderUserID is the unique identifier from the OAuth provider
ProviderUserID string
// Username is the username from the OAuth provider (e.g., GitHub login)
Username string
// DisplayName is the display name from the OAuth provider
DisplayName string
// Email is the email from the OAuth provider
Email string
// Extra contains any additional provider-specific data
Extra map[string]any
}
// OAuthError represents a translatable OAuth error
type OAuthError struct {
// MsgKey is the i18n message key
MsgKey string
// Params contains optional parameters for the message template
Params map[string]any
// RawError is the underlying error for logging purposes
RawError string
}
func (e *OAuthError) Error() string {
if e.RawError != "" {
return e.RawError
}
return e.MsgKey
}
// NewOAuthError creates a new OAuth error with the given message key
func NewOAuthError(msgKey string, params map[string]any) *OAuthError {
return &OAuthError{
MsgKey: msgKey,
Params: params,
}
}
// NewOAuthErrorWithRaw creates a new OAuth error with raw error message for logging
func NewOAuthErrorWithRaw(msgKey string, params map[string]any, rawError string) *OAuthError {
return &OAuthError{
MsgKey: msgKey,
Params: params,
RawError: rawError,
}
}

View File

@@ -224,10 +224,10 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case types.RelayFormatClaude:
if supportsAliAnthropicMessages(info.UpstreamModelName) {
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
}
return claude.ClaudeHandler(c, resp, info)
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
adaptor := openai.Adaptor{}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"
"sync"
"time"
@@ -41,88 +40,6 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
const clientHeaderPlaceholderPrefix = "{client_header:"
const (
headerPassthroughAllKey = "*"
headerPassthroughRegexPrefix = "re:"
headerPassthroughRegexPrefixV2 = "regex:"
)
var passthroughSkipHeaderNamesLower = map[string]struct{}{
// RFC 7230 hop-by-hop headers.
"connection": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"transfer-encoding": {},
"upgrade": {},
"cookie": {},
// Additional headers that should not be forwarded by name-matching passthrough rules.
"host": {},
"content-length": {},
// Do not passthrough credentials by wildcard/regex.
"authorization": {},
"x-api-key": {},
"x-goog-api-key": {},
// WebSocket handshake headers are generated by the client/dialer.
"sec-websocket-key": {},
"sec-websocket-version": {},
"sec-websocket-extensions": {},
}
var headerPassthroughRegexCache sync.Map // map[string]*regexp.Regexp
func getHeaderPassthroughRegex(pattern string) (*regexp.Regexp, error) {
pattern = strings.TrimSpace(pattern)
if pattern == "" {
return nil, errors.New("empty regex pattern")
}
if v, ok := headerPassthroughRegexCache.Load(pattern); ok {
if re, ok := v.(*regexp.Regexp); ok {
return re, nil
}
headerPassthroughRegexCache.Delete(pattern)
}
compiled, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
actual, _ := headerPassthroughRegexCache.LoadOrStore(pattern, compiled)
if re, ok := actual.(*regexp.Regexp); ok {
return re, nil
}
return compiled, nil
}
func isHeaderPassthroughRuleKey(key string) bool {
key = strings.TrimSpace(key)
if key == "" {
return false
}
if key == headerPassthroughAllKey {
return true
}
lower := strings.ToLower(key)
return strings.HasPrefix(lower, headerPassthroughRegexPrefix) || strings.HasPrefix(lower, headerPassthroughRegexPrefixV2)
}
func shouldSkipPassthroughHeader(name string) bool {
name = strings.TrimSpace(name)
if name == "" {
return true
}
lower := strings.ToLower(name)
if _, ok := passthroughSkipHeaderNamesLower[lower]; ok {
return true
}
return false
}
func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
trimmed := strings.TrimSpace(template)
if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
@@ -160,85 +77,9 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
// Supported placeholders:
// - {api_key}: resolved to the channel API key
// - {client_header:<name>}: resolved to the incoming request header value
//
// Header passthrough rules (keys only; values are ignored):
// - "*": passthrough all incoming headers by name (excluding unsafe headers)
// - "re:<regex>" / "regex:<regex>": passthrough headers whose names match the regex (Go regexp)
//
// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
headerOverride := make(map[string]string)
passAll := false
var passthroughRegex []*regexp.Regexp
for k := range info.HeadersOverride {
key := strings.TrimSpace(k)
if key == "" {
continue
}
if key == headerPassthroughAllKey {
passAll = true
continue
}
lower := strings.ToLower(key)
var pattern string
switch {
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
default:
continue
}
if pattern == "" {
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
}
compiled, err := getHeaderPassthroughRegex(pattern)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
}
passthroughRegex = append(passthroughRegex, compiled)
}
if passAll || len(passthroughRegex) > 0 {
if c == nil || c.Request == nil {
return nil, types.NewError(fmt.Errorf("missing request context for header passthrough"), types.ErrorCodeChannelHeaderOverrideInvalid)
}
for name := range c.Request.Header {
if shouldSkipPassthroughHeader(name) {
continue
}
if !passAll {
matched := false
for _, re := range passthroughRegex {
if re.MatchString(name) {
matched = true
break
}
}
if !matched {
continue
}
}
value := strings.TrimSpace(c.Request.Header.Get(name))
if value == "" {
continue
}
headerOverride[name] = value
}
}
for k, v := range info.HeadersOverride {
if isHeaderPassthroughRuleKey(k) {
continue
}
key := strings.TrimSpace(k)
if key == "" {
continue
}
str, ok := v.(string)
if !ok {
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
@@ -252,7 +93,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
continue
}
headerOverride[key] = value
headerOverride[k] = value
}
return headerOverride, nil
}

View File

@@ -49,14 +49,12 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
for i2, mediaMessage := range content {
if mediaMessage.Source != nil {
if mediaMessage.Source.Type == "url" {
// 使用统一的文件服务获取图片数据
source := types.NewURLFileSource(mediaMessage.Source.Url)
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
mediaMessage.Source.MediaType = mimeType
mediaMessage.Source.Data = base64Data
mediaMessage.Source.MediaType = fileData.MimeType
mediaMessage.Source.Data = fileData.Base64Data
mediaMessage.Source.Url = ""
mediaMessage.Source.Type = "base64"
content[i2] = mediaMessage

View File

@@ -3,6 +3,9 @@ package aws
import "strings"
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
@@ -16,7 +19,6 @@ var awsModelIDMap = map[string]string{
"claude-sonnet-4-5-20250929": "anthropic.claude-sonnet-4-5-20250929-v1:0",
"claude-haiku-4-5-20251001": "anthropic.claude-haiku-4-5-20251001-v1:0",
"claude-opus-4-5-20251101": "anthropic.claude-opus-4-5-20251101-v1:0",
"claude-opus-4-6": "anthropic.claude-opus-4-6-v1",
// Nova models
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
@@ -80,11 +82,6 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
"ap": true,
"eu": true,
},
"anthropic.claude-opus-4-6-v1": {
"us": true,
"ap": true,
"eu": true,
},
"anthropic.claude-haiku-4-5-20251001-v1:0": {
"us": true,
"ap": true,

View File

@@ -26,7 +26,6 @@ type AwsClaudeRequest struct {
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
}
func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaudeRequest, error) {

View File

@@ -233,7 +233,7 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (*types
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
}
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body)
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, claude.RequestModeMessage)
if handlerErr != nil {
return handlerErr, nil
}
@@ -264,7 +264,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
switch v := event.(type) {
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes))
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), claude.RequestModeMessage)
if respErr != nil {
return respErr, nil
}
@@ -277,7 +277,7 @@ func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor) (
}
}
claude.HandleStreamFinalResponse(c, info, claudeInfo)
claude.HandleStreamFinalResponse(c, info, claudeInfo, claude.RequestModeMessage)
return nil, claudeInfo.Usage
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
@@ -15,7 +16,13 @@ import (
"github.com/gin-gonic/gin"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -38,10 +45,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
} else {
a.RequestMode = RequestModeMessage
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
baseURL := fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
baseURL := ""
if a.RequestMode == RequestModeMessage {
baseURL = fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl)
} else {
baseURL = fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl)
}
if info.IsClaudeBetaQuery {
baseURL = baseURL + "?beta=true"
}
@@ -73,7 +90,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
return RequestOpenAI2ClaudeMessage(c, *request)
if a.RequestMode == RequestModeCompletion {
return RequestOpenAI2ClaudeComplete(*request), nil
} else {
return RequestOpenAI2ClaudeMessage(c, *request)
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -96,10 +117,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
return ClaudeStreamHandler(c, resp, info)
return ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {
return ClaudeHandler(c, resp, info)
return ClaudeHandler(c, resp, info, a.RequestMode)
}
return
}
func (a *Adaptor) GetModelList() []string {

View File

@@ -1,6 +1,10 @@
package claude
var ModelList = []string{
"claude-instant-1.2",
"claude-2",
"claude-2.0",
"claude-2.1",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
@@ -20,11 +24,6 @@ var ModelList = []string{
"claude-sonnet-4-5-20250929-thinking",
"claude-opus-4-5-20251101",
"claude-opus-4-5-20251101-thinking",
"claude-opus-4-6",
"claude-opus-4-6-max",
"claude-opus-4-6-high",
"claude-opus-4-6-medium",
"claude-opus-4-6-low",
}
var ChannelName = "claude"

View File

@@ -17,7 +17,6 @@ import (
"github.com/QuantumNous/new-api/relay/reasonmap"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -42,6 +41,37 @@ func maybeMarkClaudeRefusal(c *gin.Context, stopReason string) {
}
}
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 4096
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]any, 0, len(textRequest.Tools))
@@ -142,16 +172,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(textRequest.Model); ok && effortLevel != "" &&
strings.HasPrefix(textRequest.Model, "claude-opus-4-6") {
claudeRequest.Model = baseModel
claudeRequest.Thinking = &dto.Thinking{
Type: "adaptive",
}
claudeRequest.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
// 因为BudgetTokens 必须大于1024
@@ -343,19 +364,23 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
Type: "base64",
}
// 使用统一的文件服务获取图片数据
var source *types.FileSource
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
source = types.NewURLFileSource(imageUrl.Url)
// 是url获取图片的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(c, imageUrl.Url, "formatting image for Claude")
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
claudeMediaMessage.Source.MediaType = fileData.MimeType
claudeMediaMessage.Source.Data = fileData.Base64Data
} else {
source = types.NewBase64FileSource(imageUrl.Url, "")
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
if err != nil {
return nil, err
}
claudeMediaMessage.Source.MediaType = "image/" + format
claudeMediaMessage.Source.Data = base64String
}
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Claude")
if err != nil {
return nil, fmt.Errorf("get file data failed: %s", err.Error())
}
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = base64Data
}
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
@@ -390,7 +415,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
return &claudeRequest, nil
}
func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
@@ -404,66 +429,74 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
}
}
var choice dto.ChatCompletionsStreamResponseChoice
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
//claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
if claudeResponse.ContentBlock != nil {
// 如果是文本块,尽可能发送首段文本(若存在)
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
} else {
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
}
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCallResponse{
Index: common.GetPointer(fcIdx),
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
//claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
if claudeResponse.ContentBlock != nil {
// 如果是文本块,尽可能发送首段文本(若存在)
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
}
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCallResponse{
Index: common.GetPointer(fcIdx),
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} else {
return nil
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Delta.Content = claudeResponse.Delta.Text
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Type: "function",
Index: common.GetPointer(fcIdx),
Function: dto.FunctionResponse{
Arguments: *claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
}
}
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil
} else {
return nil
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Delta.Content = claudeResponse.Delta.Text
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Type: "function",
Index: common.GetPointer(fcIdx),
Function: dto.FunctionResponse{
Arguments: *claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
choice.Delta.ReasoningContent = claudeResponse.Delta.Thinking
}
}
} else if claudeResponse.Type == "message_delta" {
if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil
} else {
return nil
}
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
@@ -474,7 +507,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.ChatCo
return &response
}
func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -492,26 +525,39 @@ func ResponseClaude2OpenAI(claudeResponse *dto.ClaudeResponse) *dto.OpenAITextRe
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
if message.Thinking != nil {
thinkingContent = *message.Thinking
if reqMode == RequestModeCompletion {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
if message.Thinking != nil {
thinkingContent = *message.Thinking
}
case "text":
responseText = message.GetText()
}
case "text":
responseText = message.GetText()
}
}
choice := dto.OpenAITextResponseChoice{
@@ -544,67 +590,71 @@ type ClaudeResponseInfo struct {
Done bool
}
func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if claudeInfo == nil {
return false
}
if claudeInfo.Usage == nil {
claudeInfo.Usage = &dto.Usage{}
}
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
}
// message_start, 获取usage
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Message.Usage.GetCacheCreation1hTokens()
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
}
}
} else if claudeResponse.Type == "message_delta" {
// 最终的usage获取
if claudeResponse.Usage != nil {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
if claudeResponse.Usage.CacheReadInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
}
if claudeResponse.Usage.CacheCreationInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
if cacheCreation5m := claudeResponse.Usage.GetCacheCreation5mTokens(); cacheCreation5m > 0 {
claudeInfo.Usage.ClaudeCacheCreation5mTokens = cacheCreation5m
}
if cacheCreation1h := claudeResponse.Usage.GetCacheCreation1hTokens(); cacheCreation1h > 0 {
claudeInfo.Usage.ClaudeCacheCreation1hTokens = cacheCreation1h
}
if claudeResponse.Usage.OutputTokens > 0 {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
// 判断是否完整
claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
return false
if claudeResponse.Type == "message_start" {
if claudeResponse.Message != nil {
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
}
// message_start, 获取usage
if claudeResponse.Message != nil && claudeResponse.Message.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Message.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Message.Usage.GetCacheCreation1hTokens()
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Thinking)
}
}
} else if claudeResponse.Type == "message_delta" {
// 最终的usage获取
if claudeResponse.Usage != nil {
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
if claudeResponse.Usage.CacheReadInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
}
if claudeResponse.Usage.CacheCreationInputTokens > 0 {
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
if cacheCreation5m := claudeResponse.Usage.GetCacheCreation5mTokens(); cacheCreation5m > 0 {
claudeInfo.Usage.ClaudeCacheCreation5mTokens = cacheCreation5m
}
if cacheCreation1h := claudeResponse.Usage.GetCacheCreation1hTokens(); cacheCreation1h > 0 {
claudeInfo.Usage.ClaudeCacheCreation1hTokens = cacheCreation1h
}
if claudeResponse.Usage.OutputTokens > 0 {
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
}
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
}
// 判断是否完整
claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
}
}
if oaiResponse != nil {
oaiResponse.Id = claudeInfo.ResponseId
@@ -614,7 +664,7 @@ func FormatClaudeResponseInfo(claudeResponse *dto.ClaudeResponse, oaiResponse *d
return true
}
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string) *types.NewAPIError {
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
@@ -631,19 +681,24 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
maybeMarkClaudeRefusal(c, *claudeResponse.Delta.StopReason)
}
if info.RelayFormat == types.RelayFormatClaude {
FormatClaudeResponseInfo(&claudeResponse, nil, claudeInfo)
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
if claudeResponse.Message != nil {
info.UpstreamModelName = claudeResponse.Message.Model
if requestMode == RequestModeCompletion {
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
if claudeResponse.Message != nil {
info.UpstreamModelName = claudeResponse.Message.Model
}
} else if claudeResponse.Type == "content_block_delta" {
} else if claudeResponse.Type == "message_delta" {
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == types.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(&claudeResponse)
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(&claudeResponse, response, claudeInfo) {
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
return nil
}
@@ -655,15 +710,20 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return nil
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo) {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysLog("claude response usage is not complete, maybe upstream error")
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysLog("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
if info.RelayFormat == types.RelayFormatClaude {
@@ -680,7 +740,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
}
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
@@ -690,7 +750,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
}
var err *types.NewAPIError
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleStreamResponseData(c, info, claudeInfo, data)
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
if err != nil {
return false
}
@@ -700,11 +760,11 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
return nil, err
}
HandleStreamFinalResponse(c, info, claudeInfo)
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
return claudeInfo.Usage, nil
}
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte) *types.NewAPIError {
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.Unmarshal(data, &claudeResponse)
if err != nil {
@@ -714,22 +774,26 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
maybeMarkClaudeRefusal(c, claudeResponse.StopReason)
if claudeInfo.Usage == nil {
claudeInfo.Usage = &dto.Usage{}
}
if claudeResponse.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Usage.GetCacheCreation1hTokens()
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(c, claudeResponse.Completion, info.UpstreamModelName, info.GetEstimatePromptTokens())
} else {
if claudeInfo.Usage == nil {
claudeInfo.Usage = &dto.Usage{}
}
if claudeResponse.Usage != nil {
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
claudeInfo.Usage.ClaudeCacheCreation5mTokens = claudeResponse.Usage.GetCacheCreation5mTokens()
claudeInfo.Usage.ClaudeCacheCreation1hTokens = claudeResponse.Usage.GetCacheCreation1hTokens()
}
}
var responseData []byte
switch info.RelayFormat {
case types.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(&claudeResponse)
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
@@ -747,7 +811,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.Usage, *types.NewAPIError) {
defer service.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
@@ -764,7 +828,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody)
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
if handleErr != nil {
return nil, handleErr
}

View File

@@ -90,12 +90,6 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
}
}
// Codex backend requires the `instructions` field to be present.
// Keep it consistent with Codex CLI behavior by defaulting to an empty string.
if len(request.Instructions) == 0 {
request.Instructions = json.RawMessage(`""`)
}
if isCompact {
return request, nil
}
@@ -178,15 +172,5 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
req.Set("originator", "codex_cli_rs")
}
// chatgpt.com/backend-api/codex/responses is strict about Content-Type.
// Clients may omit it or include parameters like `application/json; charset=utf-8`,
// which can be rejected by the upstream. Force the exact media type.
req.Set("Content-Type", "application/json")
if info.IsStream {
req.Set("Accept", "text/event-stream")
} else if req.Get("Accept") == "" {
req.Set("Accept", "application/json")
}
return nil
}

View File

@@ -8,7 +8,7 @@ import (
var baseModelList = []string{
"gpt-5", "gpt-5-codex", "gpt-5-codex-mini",
"gpt-5.1", "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini",
"gpt-5.2", "gpt-5.2-codex", "gpt-5.3-codex",
"gpt-5.2", "gpt-5.2-codex",
}
var ModelList = withCompactModelSuffix(baseModelList)

View File

@@ -96,9 +96,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {
return claude.ClaudeHandler(c, resp, info)
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
default:
adaptor := openai.Adaptor{}

View File

@@ -466,6 +466,7 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
}
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
if part.Text == "" {
@@ -506,6 +507,10 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
}
// 提取 data URL (从 "](" 后面开始,到 ")" 之前)
dataUrl := text[bracketIdx+2 : closeIdx]
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
format, base64String, err := service.DecodeBase64FileData(dataUrl)
if err != nil {
return nil, fmt.Errorf("decode markdown base64 image data failed: %s", err.Error())
@@ -530,58 +535,69 @@ func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
})
}
} else if part.Type == dto.ContentTypeImageURL {
// 使用统一的文件服务获取图片数据
var source *types.FileSource
imageUrl := part.GetImageMedia().Url
if strings.HasPrefix(imageUrl, "http") {
source = types.NewURLFileSource(imageUrl)
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
// 是url获取文件的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(c, part.GetImageMedia().Url, "formatting image for Gemini")
if err != nil {
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
}
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
url := part.GetImageMedia().Url
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
}
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: fileData.MimeType, // 使用原始的 MimeType因为大小写可能对API有意义
Data: fileData.Base64Data,
},
})
} else {
source = types.NewBase64FileSource(imageUrl, "")
format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
}
base64Data, mimeType, err := service.GetBase64Data(c, source, "formatting image for Gemini")
if err != nil {
return nil, fmt.Errorf("get file data from '%s' failed: %w", source.GetIdentifier(), err)
}
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(mimeType)]; !ok {
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", mimeType, source.GetIdentifier(), getSupportedMimeTypesList())
}
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
} else if part.Type == dto.ContentTypeFile {
if part.GetFile().FileId != "" {
return nil, fmt.Errorf("only base64 file is supported in gemini")
}
fileSource := types.NewBase64FileSource(part.GetFile().FileData, "")
base64Data, mimeType, err := service.GetBase64Data(c, fileSource, "formatting file for Gemini")
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
MimeType: format,
Data: base64String,
},
})
} else if part.Type == dto.ContentTypeInputAudio {
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
audioSource := types.NewBase64FileSource(part.GetInputAudio().Data, "audio/"+part.GetInputAudio().Format)
base64Data, mimeType, err := service.GetBase64Data(c, audioSource, "formatting audio for Gemini")
base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, dto.GeminiPart{
InlineData: &dto.GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
MimeType: "audio/" + part.GetInputAudio().Format,
Data: base64String,
},
})
}
@@ -972,9 +988,11 @@ func unescapeMapOrSlice(data interface{}) interface{} {
func getResponseToolCall(item *dto.GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
// 移除 unescapeMapOrSlice 调用,直接使用 json.Marshal
// JSON 序列化/反序列化已经正确处理了转义字符
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
} else {
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
}
if err != nil {
return nil
@@ -1258,7 +1276,8 @@ func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
if usage.CompletionTokens <= 0 {
if info.ReceivedResponseCount > 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
} else {
usage = &dto.Usage{}

View File

@@ -103,9 +103,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {
return claude.ClaudeHandler(c, resp, info)
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
default:
adaptor := openai.Adaptor{}

View File

@@ -99,16 +99,19 @@ func openAIChatToOllamaChat(c *gin.Context, r *dto.GeneralOpenAIRequest) (*Ollam
if part.Type == dto.ContentTypeImageURL {
img := part.GetImageMedia()
if img != nil && img.Url != "" {
// 使用统一的文件服务获取图片数据
var source *types.FileSource
var base64Data string
if strings.HasPrefix(img.Url, "http") {
source = types.NewURLFileSource(img.Url)
fileData, err := service.GetFileBase64FromUrl(c, img.Url, "fetch image for ollama chat")
if err != nil {
return nil, err
}
base64Data = fileData.Base64Data
} else if strings.HasPrefix(img.Url, "data:") {
if idx := strings.Index(img.Url, ","); idx != -1 && idx+1 < len(img.Url) {
base64Data = img.Url[idx+1:]
}
} else {
source = types.NewBase64FileSource(img.Url, "")
}
base64Data, _, err := service.GetBase64Data(c, source, "fetch image for ollama chat")
if err != nil {
return nil, err
base64Data = img.Url
}
if base64Data != "" {
images = append(images, base64Data)

View File

@@ -585,9 +585,6 @@ func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommo
}
request.Model = originModel
}
if info != nil && request.Reasoning != nil && request.Reasoning.Effort != "" {
info.ReasoningEffort = request.Reasoning.Effort
}
return request, nil
}

View File

@@ -18,26 +18,6 @@ import (
"github.com/gin-gonic/gin"
)
func responsesStreamIndexKey(itemID string, idx *int) string {
if itemID == "" {
return ""
}
if idx == nil {
return itemID
}
return fmt.Sprintf("%s:%d", itemID, *idx)
}
func stringDeltaFromPrefix(prev string, next string) string {
if next == "" {
return ""
}
if prev != "" && strings.HasPrefix(next, prev) {
return next[len(prev):]
}
return next
}
func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
@@ -106,7 +86,6 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
toolCallArgsByID := make(map[string]string)
toolCallNameSent := make(map[string]bool)
toolCallCanonicalIDByItemID := make(map[string]string)
//reasoningSummaryTextByKey := make(map[string]string)
sendStartIfNeeded := func() bool {
if sentStart {
@@ -120,66 +99,6 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
return true
}
//sendReasoningDelta := func(delta string) bool {
// if delta == "" {
// return true
// }
// if !sendStartIfNeeded() {
// return false
// }
//
// usageText.WriteString(delta)
// chunk := &dto.ChatCompletionsStreamResponse{
// Id: responseId,
// Object: "chat.completion.chunk",
// Created: createAt,
// Model: model,
// Choices: []dto.ChatCompletionsStreamResponseChoice{
// {
// Index: 0,
// Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
// ReasoningContent: &delta,
// },
// },
// },
// }
// if err := helper.ObjectData(c, chunk); err != nil {
// streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
// return false
// }
// return true
//}
sendReasoningSummaryDelta := func(delta string) bool {
if delta == "" {
return true
}
if !sendStartIfNeeded() {
return false
}
usageText.WriteString(delta)
chunk := &dto.ChatCompletionsStreamResponse{
Id: responseId,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
Choices: []dto.ChatCompletionsStreamResponseChoice{
{
Index: 0,
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
ReasoningContent: &delta,
},
},
},
}
if err := helper.ObjectData(c, chunk); err != nil {
streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
return false
}
return true
}
sendToolCallDelta := func(callID string, name string, argsDelta string) bool {
if callID == "" {
return true
@@ -269,37 +188,6 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
}
}
//case "response.reasoning_text.delta":
//if !sendReasoningDelta(streamResp.Delta) {
// return false
//}
//case "response.reasoning_text.done":
case "response.reasoning_summary_text.delta":
if !sendReasoningSummaryDelta(streamResp.Delta) {
return false
}
case "response.reasoning_summary_text.done":
//case "response.reasoning_summary_part.added", "response.reasoning_summary_part.done":
// key := responsesStreamIndexKey(strings.TrimSpace(streamResp.ItemID), streamResp.SummaryIndex)
// if key == "" || streamResp.Part == nil {
// break
// }
// // Only handle summary text parts, ignore other part types.
// if streamResp.Part.Type != "" && streamResp.Part.Type != "summary_text" {
// break
// }
// prev := reasoningSummaryTextByKey[key]
// next := streamResp.Part.Text
// delta := stringDeltaFromPrefix(prev, next)
// reasoningSummaryTextByKey[key] = next
// if !sendReasoningSummaryDelta(delta) {
// return false
// }
case "response.output_text.delta":
if !sendStartIfNeeded() {
return false

View File

@@ -42,7 +42,6 @@ var claudeModelMap = map[string]string{
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5@20250929",
"claude-haiku-4-5-20251001": "claude-haiku-4-5@20251001",
"claude-opus-4-5-20251101": "claude-opus-4-5@20251101",
"claude-opus-4-6": "claude-opus-4-6",
}
const anthropicVersion = "vertex-2023-10-16"
@@ -368,7 +367,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
return claude.ClaudeStreamHandler(c, resp, info)
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationStreamHandler(c, info, resp)
@@ -381,7 +380,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else {
switch a.RequestMode {
case RequestModeClaude:
return claude.ClaudeHandler(c, resp, info)
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
return gemini.GeminiTextGenerationHandler(c, info, resp)

View File

@@ -1,8 +1,6 @@
package vertex
import (
"encoding/json"
"github.com/QuantumNous/new-api/dto"
)
@@ -19,7 +17,6 @@ type VertexAIClaudeRequest struct {
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
OutputConfig json.RawMessage `json:"output_config,omitempty"`
}
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
@@ -36,6 +33,5 @@ func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
OutputConfig: req.OutputConfig,
}
}

View File

@@ -348,9 +348,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayFormat == types.RelayFormatClaude {
if _, ok := channelconstant.ChannelSpecialBases[info.ChannelBaseUrl]; ok {
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
}
return claude.ClaudeHandler(c, resp, info)
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
}

View File

@@ -110,9 +110,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayFormat {
case types.RelayFormatClaude:
if info.IsStream {
return claude.ClaudeStreamHandler(c, resp, info)
return claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {
return claude.ClaudeHandler(c, resp, info)
return claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
default:
if info.RelayMode == relayconstant.RelayModeImagesGenerations {

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -15,7 +14,6 @@ import (
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/setting/reasoning"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
@@ -51,17 +49,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
request.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(request.Model))
}
if baseModel, effortLevel, ok := reasoning.TrimEffortSuffix(request.Model); ok && effortLevel != "" &&
strings.HasPrefix(request.Model, "claude-opus-4-6") {
request.Model = baseModel
request.Thinking = &dto.Thinking{
Type: "adaptive",
}
request.OutputConfig = json.RawMessage(fmt.Sprintf(`{"effort":"%s"}`, effortLevel))
request.TopP = 0
request.Temperature = common.GetPointer[float64](1.0)
info.UpstreamModelName = request.Model
} else if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(request.Model, "-thinking") {
if request.Thinking == nil {
// 因为BudgetTokens 必须大于1024

Some files were not shown because too many files have changed in this diff Show More