diff --git a/README.md b/README.md index 7ddd8134..5345042b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,10 @@ # Claude Relay Service +> [!CAUTION] +> **安全更新通知**:v1.1.240 及以下版本存在严重的管理员认证绕过漏洞,攻击者可未授权访问管理面板。 +> +> **请立即更新到 v1.1.241+ 版本**,或迁移到新一代项目 **[CRS 2.0 (sub2api)](https://github.com/Wei-Shaw/sub2api)** +
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) @@ -426,6 +431,8 @@ export ANTHROPIC_MODEL="gemini-2.5-pro" 如果该文件不存在,请手动创建。Windows 用户路径为 `C:\Users\你的用户名\.claude\config.json`。 +> 💡 **IntelliJ IDEA 用户推荐**:[Claude Code Plus](https://github.com/touwaeriol/claude-code-plus) - 将 Claude Code 直接集成到 IDE,支持代码理解、文件读写、命令执行。插件市场搜索 `Claude Code Plus` 即可安装。 + **Gemini CLI 设置环境变量:** **方式一(推荐):通过 Gemini Assist API 方式访问** diff --git a/README_EN.md b/README_EN.md index f9a0e1c5..037d81ac 100644 --- a/README_EN.md +++ b/README_EN.md @@ -1,5 +1,10 @@ # Claude Relay Service +> [!CAUTION] +> **Security Update**: v1.1.240 and below contain a critical admin authentication bypass vulnerability allowing unauthorized access to the admin panel. +> +> **Please update to v1.1.241+ immediately**, or migrate to the next-generation project **[CRS 2.0 (sub2api)](https://github.com/Wei-Shaw/sub2api)** +
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) diff --git a/VERSION b/VERSION index f1bc9377..9c6cacb8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.1.235 +1.1.241 diff --git a/config/config.example.js b/config/config.example.js index 9cf26002..e5e0c340 100644 --- a/config/config.example.js +++ b/config/config.example.js @@ -205,6 +205,14 @@ const config = { hotReload: process.env.HOT_RELOAD === 'true' }, + // 💰 账户余额相关配置 + accountBalance: { + // 是否允许执行自定义余额脚本(安全开关) + // 说明:脚本能力可发起任意 HTTP 请求并在服务端执行 extractor 逻辑,建议仅在受控环境开启 + // 默认保持开启;如需禁用请显式设置:BALANCE_SCRIPT_ENABLED=false + enableBalanceScript: process.env.BALANCE_SCRIPT_ENABLED !== 'false' + }, + // 📬 用户消息队列配置 // 优化说明:锁在请求发送成功后立即释放(而非请求完成后),因为 Claude API 限流基于请求发送时刻计算 userMessageQueue: { diff --git a/package-lock.json b/package-lock.json index 4fa299a4..d9ebcff0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -26,6 +26,7 @@ "ioredis": "^5.3.2", "ldapjs": "^3.0.7", "morgan": "^1.10.0", + "node-cron": "^4.2.1", "nodemailer": "^7.0.6", "ora": "^5.4.1", "rate-limiter-flexible": "^5.0.5", @@ -7028,6 +7029,15 @@ "node": ">= 0.6" } }, + "node_modules/node-cron": { + "version": "4.2.1", + "resolved": "https://registry.npmmirror.com/node-cron/-/node-cron-4.2.1.tgz", + "integrity": "sha512-lgimEHPE/QDgFlywTd8yTR61ptugX3Qer29efeyWw2rv259HtGBNn1vZVmp8lB9uo9wC0t/AT4iGqXxia+CJFg==", + "license": "ISC", + "engines": { + "node": ">=6.0.0" + } + }, "node_modules/node-domexception": { "version": "1.0.0", "resolved": "https://registry.npmmirror.com/node-domexception/-/node-domexception-1.0.0.tgz", diff --git a/package.json b/package.json index 2b7ffa25..6ef88e60 100644 --- a/package.json +++ b/package.json @@ -65,6 +65,7 @@ "ioredis": "^5.3.2", "ldapjs": "^3.0.7", "morgan": "^1.10.0", + "node-cron": "^4.2.1", "nodemailer": "^7.0.6", "ora": "^5.4.1", "rate-limiter-flexible": "^5.0.5", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index dafee4e7..9e8dc0fb 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -59,6 +59,9 @@ importers: morgan: specifier: ^1.10.0 version: 1.10.1 + node-cron: + specifier: ^4.2.1 + version: 4.2.1 nodemailer: specifier: ^7.0.6 version: 7.0.11 @@ -108,6 +111,9 @@ importers: prettier: specifier: ^3.6.2 version: 3.7.4 + prettier-plugin-tailwindcss: + specifier: ^0.7.2 + version: 0.7.2(prettier@3.7.4) supertest: specifier: ^6.3.3 version: 6.3.4 @@ -2144,6 +2150,10 @@ packages: resolution: {integrity: sha512-myRT3DiWPHqho5PrJaIRyaMv2kgYf0mUVgBNOYMuCH5Ki1yEiQaf/ZJuQ62nvpc44wL5WDbTX7yGJi1Neevw8w==} engines: {node: '>= 0.6'} + node-cron@4.2.1: + resolution: {integrity: sha512-lgimEHPE/QDgFlywTd8yTR61ptugX3Qer29efeyWw2rv259HtGBNn1vZVmp8lB9uo9wC0t/AT4iGqXxia+CJFg==} + engines: {node: '>=6.0.0'} + node-domexception@1.0.0: resolution: {integrity: sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==} engines: {node: '>=10.5.0'} @@ -2302,6 +2312,61 @@ packages: resolution: {integrity: sha512-GbK2cP9nraSSUF9N2XwUwqfzlAFlMNYYl+ShE/V+H8a9uNl/oUqB1w2EL54Jh0OlyRSd8RfWYJ3coVS4TROP2w==} engines: {node: '>=6.0.0'} + prettier-plugin-tailwindcss@0.7.2: + resolution: {integrity: sha512-LkphyK3Fw+q2HdMOoiEHWf93fNtYJwfamoKPl7UwtjFQdei/iIBoX11G6j706FzN3ymX9mPVi97qIY8328vdnA==} + engines: {node: '>=20.19'} + peerDependencies: + '@ianvs/prettier-plugin-sort-imports': '*' + '@prettier/plugin-hermes': '*' + '@prettier/plugin-oxc': '*' + '@prettier/plugin-pug': '*' + '@shopify/prettier-plugin-liquid': '*' + '@trivago/prettier-plugin-sort-imports': '*' + '@zackad/prettier-plugin-twig': '*' + prettier: ^3.0 + prettier-plugin-astro: '*' + prettier-plugin-css-order: '*' + prettier-plugin-jsdoc: '*' + prettier-plugin-marko: '*' + prettier-plugin-multiline-arrays: '*' + prettier-plugin-organize-attributes: '*' + prettier-plugin-organize-imports: '*' + prettier-plugin-sort-imports: '*' + prettier-plugin-svelte: '*' + peerDependenciesMeta: + '@ianvs/prettier-plugin-sort-imports': + optional: true + '@prettier/plugin-hermes': + optional: true + '@prettier/plugin-oxc': + optional: true + '@prettier/plugin-pug': + optional: true + '@shopify/prettier-plugin-liquid': + optional: true + '@trivago/prettier-plugin-sort-imports': + optional: true + '@zackad/prettier-plugin-twig': + optional: true + prettier-plugin-astro: + optional: true + prettier-plugin-css-order: + optional: true + prettier-plugin-jsdoc: + optional: true + prettier-plugin-marko: + optional: true + prettier-plugin-multiline-arrays: + optional: true + prettier-plugin-organize-attributes: + optional: true + prettier-plugin-organize-imports: + optional: true + prettier-plugin-sort-imports: + optional: true + prettier-plugin-svelte: + optional: true + prettier@3.7.4: resolution: {integrity: sha512-v6UNi1+3hSlVvv8fSaoUbggEM5VErKmmpGA7Pl3HF8V6uKY7rvClBOJlH6yNwQtfTueNkGVpOv/mtWL9L4bgRA==} engines: {node: '>=14'} @@ -5692,6 +5757,8 @@ snapshots: negotiator@0.6.4: {} + node-cron@4.2.1: {} + node-domexception@1.0.0: {} node-fetch@3.3.2: @@ -5840,6 +5907,10 @@ snapshots: dependencies: fast-diff: 1.3.0 + prettier-plugin-tailwindcss@0.7.2(prettier@3.7.4): + dependencies: + prettier: 3.7.4 + prettier@3.7.4: {} pretty-format@29.7.0: diff --git a/src/app.js b/src/app.js index 1ea2f325..f83be464 100644 --- a/src/app.js +++ b/src/app.js @@ -52,6 +52,16 @@ class Application { await redis.connect() logger.success('✅ Redis connected successfully') + // 💳 初始化账户余额查询服务(Provider 注册) + try { + const accountBalanceService = require('./services/accountBalanceService') + const { registerAllProviders } = require('./services/balanceProviders') + registerAllProviders(accountBalanceService) + logger.info('✅ 账户余额查询服务已初始化') + } catch (error) { + logger.warn('⚠️ 账户余额查询服务初始化失败:', error.message) + } + // 💰 初始化价格服务 logger.info('🔄 Initializing pricing service...') await pricingService.initialize() @@ -68,6 +78,10 @@ class Application { logger.info('🔄 Initializing admin credentials...') await this.initializeAdmin() + // 🔒 安全启动:清理无效/伪造的管理员会话 + logger.info('🔒 Cleaning up invalid admin sessions...') + await this.cleanupInvalidSessions() + // 💰 初始化费用数据 logger.info('💰 Checking cost data initialization...') const costInitService = require('./services/costInitService') @@ -445,6 +459,54 @@ class Application { } } + // 🔒 清理无效/伪造的管理员会话(安全启动检查) + async cleanupInvalidSessions() { + try { + const client = redis.getClient() + + // 获取所有 session:* 键 + const sessionKeys = await client.keys('session:*') + + let validCount = 0 + let invalidCount = 0 + + for (const key of sessionKeys) { + // 跳过 admin_credentials(系统凭据) + if (key === 'session:admin_credentials') { + continue + } + + const sessionData = await client.hgetall(key) + + // 检查会话完整性:必须有 username 和 loginTime + const hasUsername = !!sessionData.username + const hasLoginTime = !!sessionData.loginTime + + if (!hasUsername || !hasLoginTime) { + // 无效会话 - 可能是漏洞利用创建的伪造会话 + invalidCount++ + logger.security( + `🔒 Removing invalid session: ${key} (username: ${hasUsername}, loginTime: ${hasLoginTime})` + ) + await client.del(key) + } else { + validCount++ + } + } + + if (invalidCount > 0) { + logger.security(`🔒 Startup security check: Removed ${invalidCount} invalid sessions`) + } + + logger.success( + `✅ Session cleanup completed: ${validCount} valid, ${invalidCount} invalid removed` + ) + } catch (error) { + // 清理失败不应阻止服务启动 + logger.error('❌ Failed to cleanup invalid sessions:', error.message) + } + } + // 🔍 Redis健康检查 async checkRedisHealth() { try { @@ -600,10 +662,11 @@ class Application { const now = Date.now() let totalCleaned = 0 + let legacyCleaned = 0 // 使用 Lua 脚本批量清理所有过期项 for (const key of keys) { - // 跳过非 Sorted Set 类型的键(这些键有各自的清理逻辑) + // 跳过已知非 Sorted Set 类型的键(这些键有各自的清理逻辑) // - concurrency:queue:stats:* 是 Hash 类型 // - concurrency:queue:wait_times:* 是 List 类型 // - concurrency:queue:* (不含stats/wait_times) 是 String 类型 @@ -618,11 +681,21 @@ class Application { } try { - const cleaned = await redis.client.eval( + // 使用原子 Lua 脚本:先检查类型,再执行清理 + // 返回值:0 = 正常清理无删除,1 = 清理后删除空键,-1 = 遗留键已删除 + const result = await redis.client.eval( ` local key = KEYS[1] local now = tonumber(ARGV[1]) + -- 先检查键类型,只对 Sorted Set 执行清理 + local keyType = redis.call('TYPE', key) + if keyType.ok ~= 'zset' then + -- 非 ZSET 类型的遗留键,直接删除 + redis.call('DEL', key) + return -1 + end + -- 清理过期项 redis.call('ZREMRANGEBYSCORE', key, '-inf', now) @@ -641,8 +714,10 @@ class Application { key, now ) - if (cleaned === 1) { + if (result === 1) { totalCleaned++ + } else if (result === -1) { + legacyCleaned++ } } catch (error) { logger.error(`❌ Failed to clean concurrency key ${key}:`, error) @@ -652,6 +727,9 @@ class Application { if (totalCleaned > 0) { logger.info(`🔢 Concurrency cleanup: cleaned ${totalCleaned} expired keys`) } + if (legacyCleaned > 0) { + logger.warn(`🧹 Concurrency cleanup: removed ${legacyCleaned} legacy keys (wrong type)`) + } } catch (error) { logger.error('❌ Concurrency cleanup task failed:', error) } @@ -680,6 +758,19 @@ class Application { '🚦 Skipping concurrency queue cleanup on startup (CLEAR_CONCURRENCY_QUEUES_ON_STARTUP=false)' ) } + + // 🧪 启动账户定时测试调度器 + // 根据配置定期测试账户连通性并保存测试历史 + const accountTestSchedulerEnabled = + process.env.ACCOUNT_TEST_SCHEDULER_ENABLED !== 'false' && + config.accountTestScheduler?.enabled !== false + if (accountTestSchedulerEnabled) { + const accountTestSchedulerService = require('./services/accountTestSchedulerService') + accountTestSchedulerService.start() + logger.info('🧪 Account test scheduler service started') + } else { + logger.info('🧪 Account test scheduler service disabled') + } } setupGracefulShutdown() { @@ -734,6 +825,15 @@ class Application { logger.error('❌ Error stopping cost rank service:', error) } + // 停止账户定时测试调度器 + try { + const accountTestSchedulerService = require('./services/accountTestSchedulerService') + accountTestSchedulerService.stop() + logger.info('🧪 Account test scheduler service stopped') + } catch (error) { + logger.error('❌ Error stopping account test scheduler service:', error) + } + // 🔢 清理所有并发计数(Phase 1 修复:防止重启泄漏) try { logger.info('🔢 Cleaning up all concurrency counters...') diff --git a/src/handlers/geminiHandlers.js b/src/handlers/geminiHandlers.js index dc7dc676..05e3fd25 100644 --- a/src/handlers/geminiHandlers.js +++ b/src/handlers/geminiHandlers.js @@ -87,8 +87,7 @@ function generateSessionHash(req) { * 检查 API Key 权限 */ function checkPermissions(apiKeyData, requiredPermission = 'gemini') { - const permissions = apiKeyData?.permissions || 'all' - return permissions === 'all' || permissions === requiredPermission + return apiKeyService.hasPermission(apiKeyData?.permissions, requiredPermission) } /** diff --git a/src/middleware/auth.js b/src/middleware/auth.js index 2af4ac4d..44e3cb37 100644 --- a/src/middleware/auth.js +++ b/src/middleware/auth.js @@ -1389,6 +1389,18 @@ const authenticateAdmin = async (req, res, next) => { }) } + // 🔒 安全修复:验证会话必须字段(防止伪造会话绕过认证) + if (!adminSession.username || !adminSession.loginTime) { + logger.security( + `🔒 Corrupted admin session from ${req.ip || 'unknown'} - missing required fields (username: ${!!adminSession.username}, loginTime: ${!!adminSession.loginTime})` + ) + await redis.deleteSession(token) // 清理无效/伪造的会话 + return res.status(401).json({ + error: 'Invalid session', + message: 'Session data corrupted or incomplete' + }) + } + // 检查会话活跃性(可选:检查最后活动时间) const now = new Date() const lastActivity = new Date(adminSession.lastActivity || adminSession.loginTime) @@ -1744,9 +1756,13 @@ const requestLogger = (req, res, next) => { const referer = req.get('Referer') || 'none' // 记录请求开始 + const isDebugRoute = req.originalUrl.includes('event_logging') if (req.originalUrl !== '/health') { - // 避免健康检查日志过多 - logger.info(`▶️ [${requestId}] ${req.method} ${req.originalUrl} | IP: ${clientIP}`) + if (isDebugRoute) { + logger.debug(`▶️ [${requestId}] ${req.method} ${req.originalUrl} | IP: ${clientIP}`) + } else { + logger.info(`▶️ [${requestId}] ${req.method} ${req.originalUrl} | IP: ${clientIP}`) + } } res.on('finish', () => { @@ -1778,7 +1794,14 @@ const requestLogger = (req, res, next) => { logMetadata ) } else if (req.originalUrl !== '/health') { - logger.request(req.method, req.originalUrl, res.statusCode, duration, logMetadata) + if (isDebugRoute) { + logger.debug( + `🟢 ${req.method} ${req.originalUrl} - ${res.statusCode} (${duration}ms)`, + logMetadata + ) + } else { + logger.request(req.method, req.originalUrl, res.statusCode, duration, logMetadata) + } } // API Key相关日志 diff --git a/src/models/redis.js b/src/models/redis.js index b75c0936..e69ba727 100644 --- a/src/models/redis.js +++ b/src/models/redis.js @@ -96,7 +96,25 @@ class RedisClient { logger.warn('⚠️ Redis connection closed') }) - await this.client.connect() + // 只有在 lazyConnect 模式下才需要手动调用 connect() + // 如果 Redis 已经连接或正在连接中,则跳过 + if ( + this.client.status !== 'connecting' && + this.client.status !== 'connect' && + this.client.status !== 'ready' + ) { + await this.client.connect() + } else { + // 等待 ready 状态 + await new Promise((resolve, reject) => { + if (this.client.status === 'ready') { + resolve() + } else { + this.client.once('ready', resolve) + this.client.once('error', reject) + } + }) + } return this.client } catch (error) { logger.error('💥 Failed to connect to Redis:', error) @@ -1503,6 +1521,123 @@ class RedisClient { return await this.client.del(key) } + // 💰 账户余额缓存(API 查询结果) + async setAccountBalance(platform, accountId, balanceData, ttl = 3600) { + const key = `account_balance:${platform}:${accountId}` + + const payload = { + balance: + balanceData && balanceData.balance !== null && balanceData.balance !== undefined + ? String(balanceData.balance) + : '', + currency: balanceData?.currency || 'USD', + lastRefreshAt: balanceData?.lastRefreshAt || new Date().toISOString(), + queryMethod: balanceData?.queryMethod || 'api', + status: balanceData?.status || 'success', + errorMessage: balanceData?.errorMessage || balanceData?.error || '', + rawData: balanceData?.rawData ? JSON.stringify(balanceData.rawData) : '', + quota: balanceData?.quota ? JSON.stringify(balanceData.quota) : '' + } + + await this.client.hset(key, payload) + await this.client.expire(key, ttl) + } + + async getAccountBalance(platform, accountId) { + const key = `account_balance:${platform}:${accountId}` + const [data, ttlSeconds] = await Promise.all([this.client.hgetall(key), this.client.ttl(key)]) + + if (!data || Object.keys(data).length === 0) { + return null + } + + let rawData = null + if (data.rawData) { + try { + rawData = JSON.parse(data.rawData) + } catch (error) { + rawData = null + } + } + + let quota = null + if (data.quota) { + try { + quota = JSON.parse(data.quota) + } catch (error) { + quota = null + } + } + + return { + balance: data.balance ? parseFloat(data.balance) : null, + currency: data.currency || 'USD', + lastRefreshAt: data.lastRefreshAt || null, + queryMethod: data.queryMethod || null, + status: data.status || null, + errorMessage: data.errorMessage || '', + rawData, + quota, + ttlSeconds: Number.isFinite(ttlSeconds) ? ttlSeconds : null + } + } + + // 📊 账户余额缓存(本地统计) + async setLocalBalance(platform, accountId, statisticsData, ttl = 300) { + const key = `account_balance_local:${platform}:${accountId}` + + await this.client.hset(key, { + estimatedBalance: JSON.stringify(statisticsData || {}), + lastCalculated: new Date().toISOString() + }) + await this.client.expire(key, ttl) + } + + async getLocalBalance(platform, accountId) { + const key = `account_balance_local:${platform}:${accountId}` + const data = await this.client.hgetall(key) + + if (!data || !data.estimatedBalance) { + return null + } + + try { + return JSON.parse(data.estimatedBalance) + } catch (error) { + return null + } + } + + async deleteAccountBalance(platform, accountId) { + const key = `account_balance:${platform}:${accountId}` + const localKey = `account_balance_local:${platform}:${accountId}` + await this.client.del(key, localKey) + } + + // 🧩 账户余额脚本配置 + async setBalanceScriptConfig(platform, accountId, scriptConfig) { + const key = `account_balance_script:${platform}:${accountId}` + await this.client.set(key, JSON.stringify(scriptConfig || {})) + } + + async getBalanceScriptConfig(platform, accountId) { + const key = `account_balance_script:${platform}:${accountId}` + const raw = await this.client.get(key) + if (!raw) { + return null + } + try { + return JSON.parse(raw) + } catch (error) { + return null + } + } + + async deleteBalanceScriptConfig(platform, accountId) { + const key = `account_balance_script:${platform}:${accountId}` + return await this.client.del(key) + } + // 📈 系统统计 async getSystemStats() { const keys = await Promise.all([ @@ -2122,6 +2257,27 @@ class RedisClient { const results = [] for (const key of keys) { + // 跳过已知非 Sorted Set 类型的键 + // - concurrency:queue:stats:* 是 Hash 类型 + // - concurrency:queue:wait_times:* 是 List 类型 + // - concurrency:queue:* (不含stats/wait_times) 是 String 类型 + if ( + key.startsWith('concurrency:queue:stats:') || + key.startsWith('concurrency:queue:wait_times:') || + (key.startsWith('concurrency:queue:') && + !key.includes(':stats:') && + !key.includes(':wait_times:')) + ) { + continue + } + + // 检查键类型,只处理 Sorted Set + const keyType = await client.type(key) + if (keyType !== 'zset') { + logger.debug(`🔢 getAllConcurrencyStatus skipped non-zset key: ${key} (type: ${keyType})`) + continue + } + // 提取 apiKeyId(去掉 concurrency: 前缀) const apiKeyId = key.replace('concurrency:', '') @@ -2184,6 +2340,23 @@ class RedisClient { } } + // 检查键类型,只处理 Sorted Set + const keyType = await client.type(key) + if (keyType !== 'zset') { + logger.warn( + `⚠️ getConcurrencyStatus: key ${key} has unexpected type: ${keyType}, expected zset` + ) + return { + apiKeyId, + key, + activeCount: 0, + expiredCount: 0, + activeRequests: [], + exists: true, + invalidType: keyType + } + } + // 获取所有成员和分数 const allMembers = await client.zrange(key, 0, -1, 'WITHSCORES') @@ -2233,20 +2406,36 @@ class RedisClient { const client = this.getClientSafe() const key = `concurrency:${apiKeyId}` - // 获取清理前的状态 - const beforeCount = await client.zcard(key) + // 检查键类型 + const keyType = await client.type(key) - // 删除整个 key + let beforeCount = 0 + let isLegacy = false + + if (keyType === 'zset') { + // 正常的 zset 键,获取条目数 + beforeCount = await client.zcard(key) + } else if (keyType !== 'none') { + // 非 zset 且非空的遗留键 + isLegacy = true + logger.warn( + `⚠️ forceClearConcurrency: key ${key} has unexpected type: ${keyType}, will be deleted` + ) + } + + // 删除键(无论什么类型) await client.del(key) logger.warn( - `🧹 Force cleared concurrency for key ${apiKeyId}, removed ${beforeCount} entries` + `🧹 Force cleared concurrency for key ${apiKeyId}, removed ${beforeCount} entries${isLegacy ? ' (legacy key)' : ''}` ) return { apiKeyId, key, clearedCount: beforeCount, + type: keyType, + legacy: isLegacy, success: true } } catch (error) { @@ -2265,25 +2454,47 @@ class RedisClient { const keys = await client.keys('concurrency:*') let totalCleared = 0 + let legacyCleared = 0 const clearedKeys = [] for (const key of keys) { - const count = await client.zcard(key) - await client.del(key) - totalCleared += count - clearedKeys.push({ - key, - clearedCount: count - }) + // 跳过 queue 相关的键(它们有各自的清理逻辑) + if (key.startsWith('concurrency:queue:')) { + continue + } + + // 检查键类型 + const keyType = await client.type(key) + if (keyType === 'zset') { + const count = await client.zcard(key) + await client.del(key) + totalCleared += count + clearedKeys.push({ + key, + clearedCount: count, + type: 'zset' + }) + } else { + // 非 zset 类型的遗留键,直接删除 + await client.del(key) + legacyCleared++ + clearedKeys.push({ + key, + clearedCount: 0, + type: keyType, + legacy: true + }) + } } logger.warn( - `🧹 Force cleared all concurrency: ${keys.length} keys, ${totalCleared} total entries` + `🧹 Force cleared all concurrency: ${clearedKeys.length} keys, ${totalCleared} entries, ${legacyCleared} legacy keys` ) return { - keysCleared: keys.length, + keysCleared: clearedKeys.length, totalEntriesCleared: totalCleared, + legacyKeysCleared: legacyCleared, clearedKeys, success: true } @@ -2311,9 +2522,30 @@ class RedisClient { } let totalCleaned = 0 + let legacyCleaned = 0 const cleanedKeys = [] for (const key of keys) { + // 跳过 queue 相关的键(它们有各自的清理逻辑) + if (key.startsWith('concurrency:queue:')) { + continue + } + + // 检查键类型 + const keyType = await client.type(key) + if (keyType !== 'zset') { + // 非 zset 类型的遗留键,直接删除 + await client.del(key) + legacyCleaned++ + cleanedKeys.push({ + key, + cleanedCount: 0, + type: keyType, + legacy: true + }) + continue + } + // 只清理过期的条目 const cleaned = await client.zremrangebyscore(key, '-inf', now) if (cleaned > 0) { @@ -2332,13 +2564,14 @@ class RedisClient { } logger.info( - `🧹 Cleaned up expired concurrency: ${totalCleaned} entries from ${cleanedKeys.length} keys` + `🧹 Cleaned up expired concurrency: ${totalCleaned} entries from ${cleanedKeys.length} keys, ${legacyCleaned} legacy keys removed` ) return { keysProcessed: keys.length, keysCleaned: cleanedKeys.length, totalEntriesCleaned: totalCleaned, + legacyKeysRemoved: legacyCleaned, cleanedKeys, success: true } @@ -3157,4 +3390,249 @@ redisClient.scanConcurrencyQueueStatsKeys = async function () { } } +// ============================================================================ +// 账户测试历史相关操作 +// ============================================================================ + +const ACCOUNT_TEST_HISTORY_MAX = 5 // 保留最近5次测试记录 +const ACCOUNT_TEST_HISTORY_TTL = 86400 * 30 // 30天过期 +const ACCOUNT_TEST_CONFIG_TTL = 86400 * 365 // 测试配置保留1年(用户通常长期使用) + +/** + * 保存账户测试结果 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 (claude/gemini/openai等) + * @param {Object} testResult - 测试结果对象 + * @param {boolean} testResult.success - 是否成功 + * @param {string} testResult.message - 测试消息/响应 + * @param {number} testResult.latencyMs - 延迟毫秒数 + * @param {string} testResult.error - 错误信息(如有) + * @param {string} testResult.timestamp - 测试时间戳 + */ +redisClient.saveAccountTestResult = async function (accountId, platform, testResult) { + const key = `account:test_history:${platform}:${accountId}` + try { + const record = JSON.stringify({ + ...testResult, + timestamp: testResult.timestamp || new Date().toISOString() + }) + + // 使用 LPUSH + LTRIM 保持最近5条记录 + const client = this.getClientSafe() + await client.lpush(key, record) + await client.ltrim(key, 0, ACCOUNT_TEST_HISTORY_MAX - 1) + await client.expire(key, ACCOUNT_TEST_HISTORY_TTL) + + logger.debug(`📝 Saved test result for ${platform} account ${accountId}`) + } catch (error) { + logger.error(`Failed to save test result for ${accountId}:`, error) + } +} + +/** + * 获取账户测试历史 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 测试历史记录数组(最新在前) + */ +redisClient.getAccountTestHistory = async function (accountId, platform) { + const key = `account:test_history:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const records = await client.lrange(key, 0, -1) + return records.map((r) => JSON.parse(r)) + } catch (error) { + logger.error(`Failed to get test history for ${accountId}:`, error) + return [] + } +} + +/** + * 获取账户最新测试结果 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 最新测试结果 + */ +redisClient.getAccountLatestTestResult = async function (accountId, platform) { + const key = `account:test_history:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const record = await client.lindex(key, 0) + return record ? JSON.parse(record) : null + } catch (error) { + logger.error(`Failed to get latest test result for ${accountId}:`, error) + return null + } +} + +/** + * 批量获取多个账户的测试历史 + * @param {Array<{accountId: string, platform: string}>} accounts - 账户列表 + * @returns {Promise} 以 accountId 为 key 的测试历史映射 + */ +redisClient.getAccountsTestHistory = async function (accounts) { + const result = {} + try { + const client = this.getClientSafe() + const pipeline = client.pipeline() + + for (const { accountId, platform } of accounts) { + const key = `account:test_history:${platform}:${accountId}` + pipeline.lrange(key, 0, -1) + } + + const responses = await pipeline.exec() + + accounts.forEach(({ accountId }, index) => { + const [err, records] = responses[index] + if (!err && records) { + result[accountId] = records.map((r) => JSON.parse(r)) + } else { + result[accountId] = [] + } + }) + } catch (error) { + logger.error('Failed to get batch test history:', error) + } + return result +} + +/** + * 保存定时测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {Object} config - 配置对象 + * @param {boolean} config.enabled - 是否启用定时测试 + * @param {string} config.cronExpression - Cron 表达式 (如 "0 8 * * *" 表示每天8点) + * @param {string} config.model - 测试使用的模型 + */ +redisClient.saveAccountTestConfig = async function (accountId, platform, testConfig) { + const key = `account:test_config:${platform}:${accountId}` + try { + const client = this.getClientSafe() + await client.hset(key, { + enabled: testConfig.enabled ? 'true' : 'false', + cronExpression: testConfig.cronExpression || '0 8 * * *', // 默认每天早上8点 + model: testConfig.model || 'claude-sonnet-4-5-20250929', // 默认模型 + updatedAt: new Date().toISOString() + }) + // 设置过期时间(1年) + await client.expire(key, ACCOUNT_TEST_CONFIG_TTL) + } catch (error) { + logger.error(`Failed to save test config for ${accountId}:`, error) + } +} + +/** + * 获取定时测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 配置对象 + */ +redisClient.getAccountTestConfig = async function (accountId, platform) { + const key = `account:test_config:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const testConfig = await client.hgetall(key) + if (!testConfig || Object.keys(testConfig).length === 0) { + return null + } + // 向后兼容:如果存在旧的 testHour 字段,转换为 cron 表达式 + let { cronExpression } = testConfig + if (!cronExpression && testConfig.testHour) { + const hour = parseInt(testConfig.testHour, 10) + cronExpression = `0 ${hour} * * *` + } + return { + enabled: testConfig.enabled === 'true', + cronExpression: cronExpression || '0 8 * * *', + model: testConfig.model || 'claude-sonnet-4-5-20250929', + updatedAt: testConfig.updatedAt + } + } catch (error) { + logger.error(`Failed to get test config for ${accountId}:`, error) + return null + } +} + +/** + * 获取所有启用定时测试的账户 + * @param {string} platform - 平台类型 + * @returns {Promise} 账户ID列表及 cron 配置 + */ +redisClient.getEnabledTestAccounts = async function (platform) { + const accountIds = [] + let cursor = '0' + + try { + const client = this.getClientSafe() + do { + const [newCursor, keys] = await client.scan( + cursor, + 'MATCH', + `account:test_config:${platform}:*`, + 'COUNT', + 100 + ) + cursor = newCursor + + for (const key of keys) { + const testConfig = await client.hgetall(key) + if (testConfig && testConfig.enabled === 'true') { + const accountId = key.replace(`account:test_config:${platform}:`, '') + // 向后兼容:如果存在旧的 testHour 字段,转换为 cron 表达式 + let { cronExpression } = testConfig + if (!cronExpression && testConfig.testHour) { + const hour = parseInt(testConfig.testHour, 10) + cronExpression = `0 ${hour} * * *` + } + accountIds.push({ + accountId, + cronExpression: cronExpression || '0 8 * * *', + model: testConfig.model || 'claude-sonnet-4-5-20250929' + }) + } + } + } while (cursor !== '0') + + return accountIds + } catch (error) { + logger.error(`Failed to get enabled test accounts for ${platform}:`, error) + return [] + } +} + +/** + * 保存账户上次测试时间(用于调度器判断是否需要测试) + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + */ +redisClient.setAccountLastTestTime = async function (accountId, platform) { + const key = `account:last_test:${platform}:${accountId}` + try { + const client = this.getClientSafe() + await client.set(key, Date.now().toString(), 'EX', 86400 * 7) // 7天过期 + } catch (error) { + logger.error(`Failed to set last test time for ${accountId}:`, error) + } +} + +/** + * 获取账户上次测试时间 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 上次测试时间戳 + */ +redisClient.getAccountLastTestTime = async function (accountId, platform) { + const key = `account:last_test:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const timestamp = await client.get(key) + return timestamp ? parseInt(timestamp, 10) : null + } catch (error) { + logger.error(`Failed to get last test time for ${accountId}:`, error) + return null + } +} + module.exports = redisClient diff --git a/src/routes/admin/accountBalance.js b/src/routes/admin/accountBalance.js new file mode 100644 index 00000000..7f1d18db --- /dev/null +++ b/src/routes/admin/accountBalance.js @@ -0,0 +1,214 @@ +const express = require('express') +const { authenticateAdmin } = require('../../middleware/auth') +const logger = require('../../utils/logger') +const accountBalanceService = require('../../services/accountBalanceService') +const balanceScriptService = require('../../services/balanceScriptService') +const { isBalanceScriptEnabled } = require('../../utils/featureFlags') + +const router = express.Router() + +const ensureValidPlatform = (rawPlatform) => { + const normalized = accountBalanceService.normalizePlatform(rawPlatform) + if (!normalized) { + return { ok: false, status: 400, error: '缺少 platform 参数' } + } + + const supported = accountBalanceService.getSupportedPlatforms() + if (!supported.includes(normalized)) { + return { ok: false, status: 400, error: `不支持的平台: ${normalized}` } + } + + return { ok: true, platform: normalized } +} + +// 1) 获取账户余额(默认本地统计优先,可选触发 Provider) +// GET /admin/accounts/:accountId/balance?platform=xxx&queryApi=false +router.get('/accounts/:accountId/balance', authenticateAdmin, async (req, res) => { + try { + const { accountId } = req.params + const { platform, queryApi } = req.query + + const valid = ensureValidPlatform(platform) + if (!valid.ok) { + return res.status(valid.status).json({ success: false, error: valid.error }) + } + + const balance = await accountBalanceService.getAccountBalance(accountId, valid.platform, { + queryApi + }) + + if (!balance) { + return res.status(404).json({ success: false, error: 'Account not found' }) + } + + return res.json(balance) + } catch (error) { + logger.error('获取账户余额失败', error) + return res.status(500).json({ success: false, error: error.message }) + } +}) + +// 2) 强制刷新账户余额(强制触发查询:优先脚本;Provider 仅为降级) +// POST /admin/accounts/:accountId/balance/refresh +// Body: { platform: 'xxx' } +router.post('/accounts/:accountId/balance/refresh', authenticateAdmin, async (req, res) => { + try { + const { accountId } = req.params + const { platform } = req.body || {} + + const valid = ensureValidPlatform(platform) + if (!valid.ok) { + return res.status(valid.status).json({ success: false, error: valid.error }) + } + + logger.info(`手动刷新余额: ${valid.platform}:${accountId}`) + + const balance = await accountBalanceService.refreshAccountBalance(accountId, valid.platform) + if (!balance) { + return res.status(404).json({ success: false, error: 'Account not found' }) + } + + return res.json(balance) + } catch (error) { + logger.error('刷新账户余额失败', error) + return res.status(500).json({ success: false, error: error.message }) + } +}) + +// 3) 批量获取平台所有账户余额 +// GET /admin/accounts/balance/platform/:platform?queryApi=false +router.get('/accounts/balance/platform/:platform', authenticateAdmin, async (req, res) => { + try { + const { platform } = req.params + const { queryApi } = req.query + + const valid = ensureValidPlatform(platform) + if (!valid.ok) { + return res.status(valid.status).json({ success: false, error: valid.error }) + } + + const balances = await accountBalanceService.getAllAccountsBalance(valid.platform, { queryApi }) + + return res.json({ success: true, data: balances }) + } catch (error) { + logger.error('批量获取余额失败', error) + return res.status(500).json({ success: false, error: error.message }) + } +}) + +// 4) 获取余额汇总(Dashboard 用) +// GET /admin/accounts/balance/summary +router.get('/accounts/balance/summary', authenticateAdmin, async (req, res) => { + try { + const summary = await accountBalanceService.getBalanceSummary() + return res.json({ success: true, data: summary }) + } catch (error) { + logger.error('获取余额汇总失败', error) + return res.status(500).json({ success: false, error: error.message }) + } +}) + +// 5) 清除缓存 +// DELETE /admin/accounts/:accountId/balance/cache?platform=xxx +router.delete('/accounts/:accountId/balance/cache', authenticateAdmin, async (req, res) => { + try { + const { accountId } = req.params + const { platform } = req.query + + const valid = ensureValidPlatform(platform) + if (!valid.ok) { + return res.status(valid.status).json({ success: false, error: valid.error }) + } + + await accountBalanceService.clearCache(accountId, valid.platform) + + return res.json({ success: true, message: '缓存已清除' }) + } catch (error) { + logger.error('清除缓存失败', error) + return res.status(500).json({ success: false, error: error.message }) + } +}) + +// 6) 获取/保存/测试余额脚本配置(单账户) +router.get('/accounts/:accountId/balance/script', authenticateAdmin, async (req, res) => { + try { + const { accountId } = req.params + const { platform } = req.query + + const valid = ensureValidPlatform(platform) + if (!valid.ok) { + return res.status(valid.status).json({ success: false, error: valid.error }) + } + + const config = await accountBalanceService.redis.getBalanceScriptConfig( + valid.platform, + accountId + ) + return res.json({ success: true, data: config || null }) + } catch (error) { + logger.error('获取余额脚本配置失败', error) + return res.status(500).json({ success: false, error: error.message }) + } +}) + +router.put('/accounts/:accountId/balance/script', authenticateAdmin, async (req, res) => { + try { + const { accountId } = req.params + const { platform } = req.query + const valid = ensureValidPlatform(platform) + if (!valid.ok) { + return res.status(valid.status).json({ success: false, error: valid.error }) + } + + const payload = req.body || {} + await accountBalanceService.redis.setBalanceScriptConfig(valid.platform, accountId, payload) + return res.json({ success: true, data: payload }) + } catch (error) { + logger.error('保存余额脚本配置失败', error) + return res.status(500).json({ success: false, error: error.message }) + } +}) + +router.post('/accounts/:accountId/balance/script/test', authenticateAdmin, async (req, res) => { + try { + const { accountId } = req.params + const { platform } = req.query + const valid = ensureValidPlatform(platform) + if (!valid.ok) { + return res.status(valid.status).json({ success: false, error: valid.error }) + } + + if (!isBalanceScriptEnabled()) { + return res.status(403).json({ + success: false, + error: '余额脚本功能已禁用(可通过 BALANCE_SCRIPT_ENABLED=true 启用)' + }) + } + + const payload = req.body || {} + const { scriptBody } = payload + if (!scriptBody) { + return res.status(400).json({ success: false, error: '脚本内容不能为空' }) + } + + const result = await balanceScriptService.execute({ + scriptBody, + timeoutSeconds: payload.timeoutSeconds || 10, + variables: { + baseUrl: payload.baseUrl || '', + apiKey: payload.apiKey || '', + token: payload.token || '', + accountId, + platform: valid.platform, + extra: payload.extra || '' + } + }) + + return res.json({ success: true, data: result }) + } catch (error) { + logger.error('测试余额脚本失败', error) + return res.status(400).json({ success: false, error: error.message }) + } +}) + +module.exports = router diff --git a/src/routes/admin/apiKeys.js b/src/routes/admin/apiKeys.js index d88444bd..5994f56d 100644 --- a/src/routes/admin/apiKeys.js +++ b/src/routes/admin/apiKeys.js @@ -8,6 +8,43 @@ const config = require('../../../config/config') const router = express.Router() +// 有效的权限值列表 +const VALID_PERMISSIONS = ['claude', 'gemini', 'openai', 'droid'] + +/** + * 验证权限数组格式 + * @param {any} permissions - 权限值(可以是数组或其他) + * @returns {string|null} - 返回错误消息,null 表示验证通过 + */ +function validatePermissions(permissions) { + // 空值或未定义表示全部服务 + if (permissions === undefined || permissions === null || permissions === '') { + return null + } + // 兼容旧格式字符串 + if (typeof permissions === 'string') { + if (permissions === 'all' || VALID_PERMISSIONS.includes(permissions)) { + return null + } + return `Invalid permissions value. Must be an array of: ${VALID_PERMISSIONS.join(', ')}` + } + // 新格式数组 + if (Array.isArray(permissions)) { + // 空数组表示全部服务 + if (permissions.length === 0) { + return null + } + // 验证数组中的每个值 + for (const perm of permissions) { + if (!VALID_PERMISSIONS.includes(perm)) { + return `Invalid permission value "${perm}". Valid values are: ${VALID_PERMISSIONS.join(', ')}` + } + } + return null + } + return `Permissions must be an array. Valid values are: ${VALID_PERMISSIONS.join(', ')}` +} + // 👥 用户管理 (用于API Key分配) // 获取所有用户列表(用于API Key分配) @@ -1382,16 +1419,10 @@ router.post('/api-keys', authenticateAdmin, async (req, res) => { } } - // 验证服务权限字段 - if ( - permissions !== undefined && - permissions !== null && - permissions !== '' && - !['claude', 'gemini', 'openai', 'droid', 'all'].includes(permissions) - ) { - return res.status(400).json({ - error: 'Invalid permissions value. Must be claude, gemini, openai, droid, or all' - }) + // 验证服务权限字段(支持数组格式) + const permissionsError = validatePermissions(permissions) + if (permissionsError) { + return res.status(400).json({ error: permissionsError }) } const newKey = await apiKeyService.generateApiKey({ @@ -1481,15 +1512,10 @@ router.post('/api-keys/batch', authenticateAdmin, async (req, res) => { .json({ error: 'Base name must be less than 90 characters to allow for numbering' }) } - if ( - permissions !== undefined && - permissions !== null && - permissions !== '' && - !['claude', 'gemini', 'openai', 'droid', 'all'].includes(permissions) - ) { - return res.status(400).json({ - error: 'Invalid permissions value. Must be claude, gemini, openai, droid, or all' - }) + // 验证服务权限字段(支持数组格式) + const batchPermissionsError = validatePermissions(permissions) + if (batchPermissionsError) { + return res.status(400).json({ error: batchPermissionsError }) } // 生成批量API Keys @@ -1592,13 +1618,12 @@ router.put('/api-keys/batch', authenticateAdmin, async (req, res) => { }) } - if ( - updates.permissions !== undefined && - !['claude', 'gemini', 'openai', 'droid', 'all'].includes(updates.permissions) - ) { - return res.status(400).json({ - error: 'Invalid permissions value. Must be claude, gemini, openai, droid, or all' - }) + // 验证服务权限字段(支持数组格式) + if (updates.permissions !== undefined) { + const updatePermissionsError = validatePermissions(updates.permissions) + if (updatePermissionsError) { + return res.status(400).json({ error: updatePermissionsError }) + } } logger.info( @@ -1873,11 +1898,10 @@ router.put('/api-keys/:keyId', authenticateAdmin, async (req, res) => { } if (permissions !== undefined) { - // 验证权限值 - if (!['claude', 'gemini', 'openai', 'droid', 'all'].includes(permissions)) { - return res.status(400).json({ - error: 'Invalid permissions value. Must be claude, gemini, openai, droid, or all' - }) + // 验证服务权限字段(支持数组格式) + const singlePermissionsError = validatePermissions(permissions) + if (singlePermissionsError) { + return res.status(400).json({ error: singlePermissionsError }) } updates.permissions = permissions } diff --git a/src/routes/admin/balanceScripts.js b/src/routes/admin/balanceScripts.js new file mode 100644 index 00000000..ef7ffa01 --- /dev/null +++ b/src/routes/admin/balanceScripts.js @@ -0,0 +1,41 @@ +const express = require('express') +const { authenticateAdmin } = require('../../middleware/auth') +const balanceScriptService = require('../../services/balanceScriptService') +const router = express.Router() + +// 获取全部脚本配置列表 +router.get('/balance-scripts', authenticateAdmin, (req, res) => { + const items = balanceScriptService.listConfigs() + return res.json({ success: true, data: items }) +}) + +// 获取单个脚本配置 +router.get('/balance-scripts/:name', authenticateAdmin, (req, res) => { + const { name } = req.params + const config = balanceScriptService.getConfig(name || 'default') + return res.json({ success: true, data: config }) +}) + +// 保存脚本配置 +router.put('/balance-scripts/:name', authenticateAdmin, (req, res) => { + try { + const { name } = req.params + const saved = balanceScriptService.saveConfig(name || 'default', req.body || {}) + return res.json({ success: true, data: saved }) + } catch (error) { + return res.status(400).json({ success: false, error: error.message }) + } +}) + +// 测试脚本(不落库) +router.post('/balance-scripts/:name/test', authenticateAdmin, async (req, res) => { + try { + const { name } = req.params + const result = await balanceScriptService.testScript(name || 'default', req.body || {}) + return res.json({ success: true, data: result }) + } catch (error) { + return res.status(400).json({ success: false, error: error.message }) + } +}) + +module.exports = router diff --git a/src/routes/admin/claudeAccounts.js b/src/routes/admin/claudeAccounts.js index 3443d394..d079e346 100644 --- a/src/routes/admin/claudeAccounts.js +++ b/src/routes/admin/claudeAccounts.js @@ -9,6 +9,7 @@ const router = express.Router() const claudeAccountService = require('../../services/claudeAccountService') const claudeRelayService = require('../../services/claudeRelayService') const accountGroupService = require('../../services/accountGroupService') +const accountTestSchedulerService = require('../../services/accountTestSchedulerService') const apiKeyService = require('../../services/apiKeyService') const redis = require('../../models/redis') const { authenticateAdmin } = require('../../middleware/auth') @@ -583,7 +584,9 @@ router.post('/claude-accounts', authenticateAdmin, async (req, res) => { useUnifiedClientId, unifiedClientId, expiresAt, - extInfo + extInfo, + maxConcurrency, + interceptWarmup } = req.body if (!name) { @@ -628,7 +631,9 @@ router.post('/claude-accounts', authenticateAdmin, async (req, res) => { useUnifiedClientId: useUnifiedClientId === true, // 默认为false unifiedClientId: unifiedClientId || '', // 统一的客户端标识 expiresAt: expiresAt || null, // 账户订阅到期时间 - extInfo: extInfo || null + extInfo: extInfo || null, + maxConcurrency: maxConcurrency || 0, // 账户级串行队列:0=使用全局配置,>0=强制启用 + interceptWarmup: interceptWarmup === true // 拦截预热请求:默认为false }) // 如果是分组类型,将账户添加到分组 @@ -903,4 +908,219 @@ router.post('/claude-accounts/:accountId/test', authenticateAdmin, async (req, r } }) +// ============================================================================ +// 账户定时测试相关端点 +// ============================================================================ + +// 获取账户测试历史 +router.get('/claude-accounts/:accountId/test-history', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + + try { + const history = await redis.getAccountTestHistory(accountId, 'claude') + return res.json({ + success: true, + data: { + accountId, + platform: 'claude', + history + } + }) + } catch (error) { + logger.error(`❌ Failed to get test history for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to get test history', + message: error.message + }) + } +}) + +// 获取账户定时测试配置 +router.get('/claude-accounts/:accountId/test-config', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + + try { + const testConfig = await redis.getAccountTestConfig(accountId, 'claude') + return res.json({ + success: true, + data: { + accountId, + platform: 'claude', + config: testConfig || { + enabled: false, + cronExpression: '0 8 * * *', + model: 'claude-sonnet-4-5-20250929' + } + } + }) + } catch (error) { + logger.error(`❌ Failed to get test config for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to get test config', + message: error.message + }) + } +}) + +// 设置账户定时测试配置 +router.put('/claude-accounts/:accountId/test-config', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + const { enabled, cronExpression, model } = req.body + + try { + // 验证 enabled 参数 + if (typeof enabled !== 'boolean') { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'enabled must be a boolean' + }) + } + + // 验证 cronExpression 参数 + if (!cronExpression || typeof cronExpression !== 'string') { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'cronExpression is required and must be a string' + }) + } + + // 限制 cronExpression 长度防止 DoS + const MAX_CRON_LENGTH = 100 + if (cronExpression.length > MAX_CRON_LENGTH) { + return res.status(400).json({ + error: 'Invalid parameter', + message: `cronExpression too long (max ${MAX_CRON_LENGTH} characters)` + }) + } + + // 使用 service 的方法验证 cron 表达式 + if (!accountTestSchedulerService.validateCronExpression(cronExpression)) { + return res.status(400).json({ + error: 'Invalid parameter', + message: `Invalid cron expression: ${cronExpression}. Format: "minute hour day month weekday" (e.g., "0 8 * * *" for daily at 8:00)` + }) + } + + // 验证模型参数 + const testModel = model || 'claude-sonnet-4-5-20250929' + if (typeof testModel !== 'string' || testModel.length > 256) { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'model must be a valid string (max 256 characters)' + }) + } + + // 检查账户是否存在 + const account = await claudeAccountService.getAccount(accountId) + if (!account) { + return res.status(404).json({ + error: 'Account not found', + message: `Claude account ${accountId} not found` + }) + } + + // 保存配置 + await redis.saveAccountTestConfig(accountId, 'claude', { + enabled, + cronExpression, + model: testModel + }) + + logger.success( + `📝 Updated test config for Claude account ${accountId}: enabled=${enabled}, cronExpression=${cronExpression}, model=${testModel}` + ) + + return res.json({ + success: true, + message: 'Test config updated successfully', + data: { + accountId, + platform: 'claude', + config: { enabled, cronExpression, model: testModel } + } + }) + } catch (error) { + logger.error(`❌ Failed to update test config for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to update test config', + message: error.message + }) + } +}) + +// 手动触发账户测试(非流式,返回JSON结果) +router.post('/claude-accounts/:accountId/test-sync', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + + try { + // 检查账户是否存在 + const account = await claudeAccountService.getAccount(accountId) + if (!account) { + return res.status(404).json({ + error: 'Account not found', + message: `Claude account ${accountId} not found` + }) + } + + logger.info(`🧪 Manual sync test triggered for Claude account: ${accountId}`) + + // 执行测试 + const testResult = await claudeRelayService.testAccountConnectionSync(accountId) + + // 保存测试结果到历史 + await redis.saveAccountTestResult(accountId, 'claude', testResult) + await redis.setAccountLastTestTime(accountId, 'claude') + + return res.json({ + success: true, + data: { + accountId, + platform: 'claude', + result: testResult + } + }) + } catch (error) { + logger.error(`❌ Failed to run sync test for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to run test', + message: error.message + }) + } +}) + +// 批量获取多个账户的测试历史 +router.post('/claude-accounts/batch-test-history', authenticateAdmin, async (req, res) => { + const { accountIds } = req.body + + try { + if (!Array.isArray(accountIds) || accountIds.length === 0) { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'accountIds must be a non-empty array' + }) + } + + // 限制批量查询数量 + const limitedIds = accountIds.slice(0, 100) + + const accounts = limitedIds.map((accountId) => ({ + accountId, + platform: 'claude' + })) + + const historyMap = await redis.getAccountsTestHistory(accounts) + + return res.json({ + success: true, + data: historyMap + }) + } catch (error) { + logger.error('❌ Failed to get batch test history:', error) + return res.status(500).json({ + error: 'Failed to get batch test history', + message: error.message + }) + } +}) + module.exports = router diff --git a/src/routes/admin/claudeConsoleAccounts.js b/src/routes/admin/claudeConsoleAccounts.js index 311806a3..fc0fcf62 100644 --- a/src/routes/admin/claudeConsoleAccounts.js +++ b/src/routes/admin/claudeConsoleAccounts.js @@ -132,7 +132,8 @@ router.post('/claude-console-accounts', authenticateAdmin, async (req, res) => { dailyQuota, quotaResetTime, maxConcurrentTasks, - disableAutoProtection + disableAutoProtection, + interceptWarmup } = req.body if (!name || !apiUrl || !apiKey) { @@ -186,7 +187,8 @@ router.post('/claude-console-accounts', authenticateAdmin, async (req, res) => { maxConcurrentTasks !== undefined && maxConcurrentTasks !== null ? Number(maxConcurrentTasks) : 0, - disableAutoProtection: normalizedDisableAutoProtection + disableAutoProtection: normalizedDisableAutoProtection, + interceptWarmup: interceptWarmup === true || interceptWarmup === 'true' }) // 如果是分组类型,将账户添加到分组(CCR 归属 Claude 平台分组) diff --git a/src/routes/admin/index.js b/src/routes/admin/index.js index c91aa5e7..7fe901c7 100644 --- a/src/routes/admin/index.js +++ b/src/routes/admin/index.js @@ -21,9 +21,11 @@ const openaiResponsesAccountsRoutes = require('./openaiResponsesAccounts') const droidAccountsRoutes = require('./droidAccounts') const dashboardRoutes = require('./dashboard') const usageStatsRoutes = require('./usageStats') +const accountBalanceRoutes = require('./accountBalance') const systemRoutes = require('./system') const concurrencyRoutes = require('./concurrency') const claudeRelayConfigRoutes = require('./claudeRelayConfig') +const syncRoutes = require('./sync') // 挂载所有子路由 // 使用完整路径的模块(直接挂载到根路径) @@ -36,9 +38,11 @@ router.use('/', openaiResponsesAccountsRoutes) router.use('/', droidAccountsRoutes) router.use('/', dashboardRoutes) router.use('/', usageStatsRoutes) +router.use('/', accountBalanceRoutes) router.use('/', systemRoutes) router.use('/', concurrencyRoutes) router.use('/', claudeRelayConfigRoutes) +router.use('/', syncRoutes) // 使用相对路径的模块(需要指定基础路径前缀) router.use('/account-groups', accountGroupsRoutes) diff --git a/src/routes/admin/sync.js b/src/routes/admin/sync.js new file mode 100644 index 00000000..6345e810 --- /dev/null +++ b/src/routes/admin/sync.js @@ -0,0 +1,460 @@ +/** + * Admin Routes - Sync / Export (for migration) + * Exports account data (including secrets) for safe server-to-server syncing. + */ + +const express = require('express') +const router = express.Router() + +const { authenticateAdmin } = require('../../middleware/auth') +const redis = require('../../models/redis') +const claudeAccountService = require('../../services/claudeAccountService') +const claudeConsoleAccountService = require('../../services/claudeConsoleAccountService') +const openaiAccountService = require('../../services/openaiAccountService') +const openaiResponsesAccountService = require('../../services/openaiResponsesAccountService') +const logger = require('../../utils/logger') + +function toBool(value, defaultValue = false) { + if (value === undefined || value === null || value === '') { + return defaultValue + } + if (value === true || value === 'true') { + return true + } + if (value === false || value === 'false') { + return false + } + return defaultValue +} + +function normalizeProxy(proxy) { + if (!proxy || typeof proxy !== 'object') { + return null + } + + const protocol = proxy.protocol || proxy.type || proxy.scheme || '' + const host = proxy.host || '' + const port = Number(proxy.port || 0) + + if (!protocol || !host || !Number.isFinite(port) || port <= 0) { + return null + } + + return { + protocol: String(protocol), + host: String(host), + port, + username: proxy.username ? String(proxy.username) : '', + password: proxy.password ? String(proxy.password) : '' + } +} + +function buildModelMappingFromSupportedModels(supportedModels) { + if (!supportedModels) { + return null + } + + if (Array.isArray(supportedModels)) { + const mapping = {} + for (const model of supportedModels) { + if (typeof model === 'string' && model.trim()) { + mapping[model.trim()] = model.trim() + } + } + return Object.keys(mapping).length ? mapping : null + } + + if (typeof supportedModels === 'object') { + const mapping = {} + for (const [from, to] of Object.entries(supportedModels)) { + if (typeof from === 'string' && typeof to === 'string' && from.trim() && to.trim()) { + mapping[from.trim()] = to.trim() + } + } + return Object.keys(mapping).length ? mapping : null + } + + return null +} + +function safeParseJson(raw, fallback = null) { + if (!raw || typeof raw !== 'string') { + return fallback + } + try { + return JSON.parse(raw) + } catch (_) { + return fallback + } +} + +// Export accounts for migration (includes secrets). +// GET /admin/sync/export-accounts?include_secrets=true +router.get('/sync/export-accounts', authenticateAdmin, async (req, res) => { + try { + const includeSecrets = toBool(req.query.include_secrets, false) + if (!includeSecrets) { + return res.status(400).json({ + success: false, + error: 'include_secrets_required', + message: 'Set include_secrets=true to export secrets' + }) + } + + // ===== Claude official OAuth / Setup Token accounts ===== + const rawClaudeAccounts = await redis.getAllClaudeAccounts() + const claudeAccounts = rawClaudeAccounts.map((account) => { + // Backward compatible extraction: prefer individual fields, fallback to claudeAiOauth JSON blob. + let decryptedClaudeAiOauth = null + if (account.claudeAiOauth) { + try { + const raw = claudeAccountService._decryptSensitiveData(account.claudeAiOauth) + decryptedClaudeAiOauth = raw ? JSON.parse(raw) : null + } catch (_) { + decryptedClaudeAiOauth = null + } + } + + const rawScopes = + account.scopes && account.scopes.trim() + ? account.scopes + : decryptedClaudeAiOauth?.scopes + ? decryptedClaudeAiOauth.scopes.join(' ') + : '' + + const scopes = rawScopes && rawScopes.trim() ? rawScopes.trim().split(' ') : [] + const isOAuth = scopes.includes('user:profile') && scopes.includes('user:inference') + const authType = isOAuth ? 'oauth' : 'setup-token' + + const accessToken = + account.accessToken && String(account.accessToken).trim() + ? claudeAccountService._decryptSensitiveData(account.accessToken) + : decryptedClaudeAiOauth?.accessToken || '' + + const refreshToken = + account.refreshToken && String(account.refreshToken).trim() + ? claudeAccountService._decryptSensitiveData(account.refreshToken) + : decryptedClaudeAiOauth?.refreshToken || '' + + let expiresAt = null + const expiresAtMs = Number.parseInt(account.expiresAt, 10) + if (Number.isFinite(expiresAtMs) && expiresAtMs > 0) { + expiresAt = new Date(expiresAtMs).toISOString() + } else if (decryptedClaudeAiOauth?.expiresAt) { + try { + expiresAt = new Date(Number(decryptedClaudeAiOauth.expiresAt)).toISOString() + } catch (_) { + expiresAt = null + } + } + + const proxy = account.proxy ? normalizeProxy(safeParseJson(account.proxy)) : null + + // 🔧 Parse subscriptionInfo to extract org_uuid and account_uuid + let orgUuid = null + let accountUuid = null + if (account.subscriptionInfo) { + try { + const subscriptionInfo = JSON.parse(account.subscriptionInfo) + orgUuid = subscriptionInfo.organizationUuid || null + accountUuid = subscriptionInfo.accountUuid || null + } catch (_) { + // Ignore parse errors + } + } + + // 🔧 Calculate expires_in from expires_at + let expiresIn = null + if (expiresAt) { + try { + const expiresAtTime = new Date(expiresAt).getTime() + const nowTime = Date.now() + const diffSeconds = Math.floor((expiresAtTime - nowTime) / 1000) + if (diffSeconds > 0) { + expiresIn = diffSeconds + } + } catch (_) { + // Ignore calculation errors + } + } + // 🔧 Use default expires_in if calculation failed (Anthropic OAuth: 8 hours) + if (!expiresIn && isOAuth) { + expiresIn = 28800 // 8 hours + } + + const credentials = { + access_token: accessToken, + refresh_token: refreshToken || undefined, + expires_at: expiresAt || undefined, + expires_in: expiresIn || undefined, + scope: scopes.join(' ') || undefined, + token_type: 'Bearer' + } + // 🔧 Add auth info as top-level credentials fields + if (orgUuid) { + credentials.org_uuid = orgUuid + } + if (accountUuid) { + credentials.account_uuid = accountUuid + } + + // 🔧 Store complete original CRS data in extra + const extra = { + crs_account_id: account.id, + crs_kind: 'claude-account', + crs_id: account.id, + crs_name: account.name, + crs_description: account.description || '', + crs_platform: account.platform || 'claude', + crs_auth_type: authType, + crs_is_active: account.isActive === 'true', + crs_schedulable: account.schedulable !== 'false', + crs_priority: Number.parseInt(account.priority, 10) || 50, + crs_status: account.status || 'active', + crs_scopes: scopes, + crs_subscription_info: account.subscriptionInfo || undefined + } + + return { + kind: 'claude-account', + id: account.id, + name: account.name, + description: account.description || '', + platform: account.platform || 'claude', + authType, + isActive: account.isActive === 'true', + schedulable: account.schedulable !== 'false', + priority: Number.parseInt(account.priority, 10) || 50, + status: account.status || 'active', + proxy, + credentials, + extra + } + }) + + // ===== Claude Console API Key accounts ===== + const claudeConsoleSummaries = await claudeConsoleAccountService.getAllAccounts() + const claudeConsoleAccounts = [] + for (const summary of claudeConsoleSummaries) { + const full = await claudeConsoleAccountService.getAccount(summary.id) + if (!full) { + continue + } + + const proxy = normalizeProxy(full.proxy) + const modelMapping = buildModelMappingFromSupportedModels(full.supportedModels) + + const credentials = { + api_key: full.apiKey, + base_url: full.apiUrl + } + + if (modelMapping) { + credentials.model_mapping = modelMapping + } + + if (full.userAgent) { + credentials.user_agent = full.userAgent + } + + claudeConsoleAccounts.push({ + kind: 'claude-console-account', + id: full.id, + name: full.name, + description: full.description || '', + platform: full.platform || 'claude-console', + isActive: full.isActive === true, + schedulable: full.schedulable !== false, + priority: Number.parseInt(full.priority, 10) || 50, + status: full.status || 'active', + proxy, + maxConcurrentTasks: Number.parseInt(full.maxConcurrentTasks, 10) || 0, + credentials, + extra: { + crs_account_id: full.id, + crs_kind: 'claude-console-account', + crs_id: full.id, + crs_name: full.name, + crs_description: full.description || '', + crs_platform: full.platform || 'claude-console', + crs_is_active: full.isActive === true, + crs_schedulable: full.schedulable !== false, + crs_priority: Number.parseInt(full.priority, 10) || 50, + crs_status: full.status || 'active' + } + }) + } + + // ===== OpenAI OAuth accounts ===== + const openaiOAuthAccounts = [] + { + const client = redis.getClientSafe() + const openaiKeys = await client.keys('openai:account:*') + for (const key of openaiKeys) { + const id = key.split(':').slice(2).join(':') + const account = await openaiAccountService.getAccount(id) + if (!account) { + continue + } + + const accessToken = account.accessToken + ? openaiAccountService.decrypt(account.accessToken) + : '' + if (!accessToken) { + // Skip broken/legacy records without decryptable token + continue + } + + const scopes = + account.scopes && typeof account.scopes === 'string' && account.scopes.trim() + ? account.scopes.trim().split(' ') + : [] + + const proxy = normalizeProxy(account.proxy) + + // 🔧 Calculate expires_in from expires_at + let expiresIn = null + if (account.expiresAt) { + try { + const expiresAtTime = new Date(account.expiresAt).getTime() + const nowTime = Date.now() + const diffSeconds = Math.floor((expiresAtTime - nowTime) / 1000) + if (diffSeconds > 0) { + expiresIn = diffSeconds + } + } catch (_) { + // Ignore calculation errors + } + } + // 🔧 Use default expires_in if calculation failed (OpenAI OAuth: 10 days) + if (!expiresIn) { + expiresIn = 864000 // 10 days + } + + const credentials = { + access_token: accessToken, + refresh_token: account.refreshToken || undefined, + id_token: account.idToken || undefined, + expires_at: account.expiresAt || undefined, + expires_in: expiresIn || undefined, + scope: scopes.join(' ') || undefined, + token_type: 'Bearer' + } + // 🔧 Add auth info as top-level credentials fields + if (account.accountId) { + credentials.chatgpt_account_id = account.accountId + } + if (account.chatgptUserId) { + credentials.chatgpt_user_id = account.chatgptUserId + } + if (account.organizationId) { + credentials.organization_id = account.organizationId + } + + // 🔧 Store complete original CRS data in extra + const extra = { + crs_account_id: account.id, + crs_kind: 'openai-oauth-account', + crs_id: account.id, + crs_name: account.name, + crs_description: account.description || '', + crs_platform: account.platform || 'openai', + crs_is_active: account.isActive === 'true', + crs_schedulable: account.schedulable !== 'false', + crs_priority: Number.parseInt(account.priority, 10) || 50, + crs_status: account.status || 'active', + crs_scopes: scopes, + crs_email: account.email || undefined, + crs_chatgpt_account_id: account.accountId || undefined, + crs_chatgpt_user_id: account.chatgptUserId || undefined, + crs_organization_id: account.organizationId || undefined + } + + openaiOAuthAccounts.push({ + kind: 'openai-oauth-account', + id: account.id, + name: account.name, + description: account.description || '', + platform: account.platform || 'openai', + authType: 'oauth', + isActive: account.isActive === 'true', + schedulable: account.schedulable !== 'false', + priority: Number.parseInt(account.priority, 10) || 50, + status: account.status || 'active', + proxy, + credentials, + extra + }) + } + } + + // ===== OpenAI Responses API Key accounts ===== + const openaiResponsesAccounts = [] + const client = redis.getClientSafe() + const openaiResponseKeys = await client.keys('openai_responses_account:*') + for (const key of openaiResponseKeys) { + const id = key.split(':').slice(1).join(':') + const full = await openaiResponsesAccountService.getAccount(id) + if (!full) { + continue + } + + const proxy = normalizeProxy(full.proxy) + + const credentials = { + api_key: full.apiKey, + base_url: full.baseApi + } + + if (full.userAgent) { + credentials.user_agent = full.userAgent + } + + openaiResponsesAccounts.push({ + kind: 'openai-responses-account', + id: full.id, + name: full.name, + description: full.description || '', + platform: full.platform || 'openai-responses', + isActive: full.isActive === 'true', + schedulable: full.schedulable !== 'false', + priority: Number.parseInt(full.priority, 10) || 50, + status: full.status || 'active', + proxy, + credentials, + extra: { + crs_account_id: full.id, + crs_kind: 'openai-responses-account', + crs_id: full.id, + crs_name: full.name, + crs_description: full.description || '', + crs_platform: full.platform || 'openai-responses', + crs_is_active: full.isActive === 'true', + crs_schedulable: full.schedulable !== 'false', + crs_priority: Number.parseInt(full.priority, 10) || 50, + crs_status: full.status || 'active' + } + }) + } + + return res.json({ + success: true, + data: { + exportedAt: new Date().toISOString(), + claudeAccounts, + claudeConsoleAccounts, + openaiOAuthAccounts, + openaiResponsesAccounts + } + }) + } catch (error) { + logger.error('❌ Failed to export accounts for sync:', error) + return res.status(500).json({ + success: false, + error: 'export_failed', + message: error.message + }) + } +}) + +module.exports = router diff --git a/src/routes/api.js b/src/routes/api.js index adc49cae..8047a51d 100644 --- a/src/routes/api.js +++ b/src/routes/api.js @@ -12,6 +12,13 @@ const { getEffectiveModel, parseVendorPrefixedModel } = require('../utils/modelH const sessionHelper = require('../utils/sessionHelper') const { updateRateLimitCounters } = require('../utils/rateLimitHelper') const claudeRelayConfigService = require('../services/claudeRelayConfigService') +const claudeAccountService = require('../services/claudeAccountService') +const claudeConsoleAccountService = require('../services/claudeConsoleAccountService') +const { + isWarmupRequest, + buildMockWarmupResponse, + sendMockWarmupStream +} = require('../utils/warmupInterceptor') const { sanitizeUpstreamError } = require('../utils/errorSanitizer') const { dumpAnthropicMessagesRequest } = require('../utils/anthropicRequestDump') const { @@ -115,6 +122,16 @@ async function handleMessagesRequest(req, res) { try { const startTime = Date.now() + // Claude 服务权限校验,阻止未授权的 Key + if (!apiKeyService.hasPermission(req.apiKey.permissions, 'claude')) { + return res.status(403).json({ + error: { + type: 'permission_error', + message: '此 API Key 无权访问 Claude 服务' + } + }) + } + // 🔄 并发满额重试标志:最多重试一次(使用req对象存储状态) if (req._concurrencyRetryAttempted === undefined) { req._concurrencyRetryAttempted = false @@ -398,6 +415,23 @@ async function handleMessagesRequest(req, res) { } } + // 🔥 预热请求拦截检查(在转发之前) + if (accountType === 'claude-official' || accountType === 'claude-console') { + const account = + accountType === 'claude-official' + ? await claudeAccountService.getAccount(accountId) + : await claudeConsoleAccountService.getAccount(accountId) + + if (account?.interceptWarmup === 'true' && isWarmupRequest(req.body)) { + logger.api(`🔥 Warmup request intercepted for account: ${account.name} (${accountId})`) + if (isStream) { + return sendMockWarmupStream(res, req.body.model) + } else { + return res.json(buildMockWarmupResponse(req.body.model)) + } + } + } + // 根据账号类型选择对应的转发服务并调用 if (accountType === 'claude-official') { // 官方Claude账号使用原有的转发服务(会自己选择账号) @@ -897,6 +931,21 @@ async function handleMessagesRequest(req, res) { } } + // 🔥 预热请求拦截检查(非流式,在转发之前) + if (accountType === 'claude-official' || accountType === 'claude-console') { + const account = + accountType === 'claude-official' + ? await claudeAccountService.getAccount(accountId) + : await claudeConsoleAccountService.getAccount(accountId) + + if (account?.interceptWarmup === 'true' && isWarmupRequest(req.body)) { + logger.api( + `🔥 Warmup request intercepted (non-stream) for account: ${account.name} (${accountId})` + ) + return res.json(buildMockWarmupResponse(req.body.model)) + } + } + // 根据账号类型选择对应的转发服务 let response logger.debug(`[DEBUG] Request query params: ${JSON.stringify(req.query)}`) @@ -1465,9 +1514,6 @@ router.post('/v1/messages/count_tokens', authenticateApiKey, async (req, res) => const maxAttempts = 2 let attempt = 0 - // 引入 claudeConsoleAccountService 用于检查 count_tokens 可用性 - const claudeConsoleAccountService = require('../services/claudeConsoleAccountService') - const processRequest = async () => { const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey( req.apiKey, @@ -1663,5 +1709,10 @@ router.post('/v1/messages/count_tokens', authenticateApiKey, async (req, res) => } }) +// Claude Code 客户端遥测端点 - 返回成功响应避免 404 日志 +router.post('/api/event_logging/batch', (req, res) => { + res.status(200).json({ success: true }) +}) + module.exports = router module.exports.handleMessagesRequest = handleMessagesRequest diff --git a/src/routes/droidRoutes.js b/src/routes/droidRoutes.js index f8479cde..b6d9932a 100644 --- a/src/routes/droidRoutes.js +++ b/src/routes/droidRoutes.js @@ -4,12 +4,12 @@ const { authenticateApiKey } = require('../middleware/auth') const droidRelayService = require('../services/droidRelayService') const sessionHelper = require('../utils/sessionHelper') const logger = require('../utils/logger') +const apiKeyService = require('../services/apiKeyService') const router = express.Router() function hasDroidPermission(apiKeyData) { - const permissions = apiKeyData?.permissions || 'all' - return permissions === 'all' || permissions === 'droid' + return apiKeyService.hasPermission(apiKeyData?.permissions, 'droid') } /** diff --git a/src/routes/openaiGeminiRoutes.js b/src/routes/openaiGeminiRoutes.js index 458aaadb..511ca248 100644 --- a/src/routes/openaiGeminiRoutes.js +++ b/src/routes/openaiGeminiRoutes.js @@ -6,6 +6,7 @@ const geminiAccountService = require('../services/geminiAccountService') const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler') const { getAvailableModels } = require('../services/geminiRelayService') const crypto = require('crypto') +const apiKeyService = require('../services/apiKeyService') // 生成会话哈希 function generateSessionHash(req) { @@ -31,8 +32,7 @@ function ensureAntigravityProjectId(account) { // 检查 API Key 权限 function checkPermissions(apiKeyData, requiredPermission = 'gemini') { - const permissions = apiKeyData.permissions || 'all' - return permissions === 'all' || permissions === requiredPermission + return apiKeyService.hasPermission(apiKeyData?.permissions, requiredPermission) } // 转换 OpenAI 消息格式到 Gemini 格式 @@ -532,7 +532,6 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => { // 记录使用统计 if (!usageReported && totalUsage.totalTokenCount > 0) { try { - const apiKeyService = require('../services/apiKeyService') await apiKeyService.recordUsage( apiKeyData.id, totalUsage.promptTokenCount || 0, @@ -634,7 +633,6 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => { // 记录使用统计 if (openaiResponse.usage) { try { - const apiKeyService = require('../services/apiKeyService') await apiKeyService.recordUsage( apiKeyData.id, openaiResponse.usage.prompt_tokens || 0, diff --git a/src/routes/openaiRoutes.js b/src/routes/openaiRoutes.js index 7faf9e87..7f1b04f1 100644 --- a/src/routes/openaiRoutes.js +++ b/src/routes/openaiRoutes.js @@ -20,8 +20,7 @@ function createProxyAgent(proxy) { // 检查 API Key 是否具备 OpenAI 权限 function checkOpenAIPermissions(apiKeyData) { - const permissions = apiKeyData?.permissions || 'all' - return permissions === 'all' || permissions === 'openai' + return apiKeyService.hasPermission(apiKeyData?.permissions, 'openai') } function normalizeHeaders(headers = {}) { diff --git a/src/routes/unified.js b/src/routes/unified.js index a8a8e69d..57c4fe80 100644 --- a/src/routes/unified.js +++ b/src/routes/unified.js @@ -8,6 +8,7 @@ const { handleStreamGenerateContent: geminiHandleStreamGenerateContent } = require('../handlers/geminiHandlers') const openaiRoutes = require('./openaiRoutes') +const apiKeyService = require('../services/apiKeyService') const router = express.Router() @@ -73,7 +74,7 @@ async function routeToBackend(req, res, requestedModel) { return await openaiRoutes.handleResponses(req, res) } else if (backend === 'gemini') { // Gemini 后端 - if (permissions !== 'all' && permissions !== 'gemini') { + if (!apiKeyService.hasPermission(permissions, 'gemini')) { return res.status(403).json({ error: { message: 'This API key does not have permission to access Gemini', diff --git a/src/routes/web.js b/src/routes/web.js index 8bbdd435..ed3bfa57 100644 --- a/src/routes/web.js +++ b/src/routes/web.js @@ -164,13 +164,27 @@ router.post('/auth/change-password', async (req, res) => { // 获取当前会话 const sessionData = await redis.getSession(token) - if (!sessionData) { + + // 🔒 安全修复:检查空对象 + if (!sessionData || Object.keys(sessionData).length === 0) { return res.status(401).json({ error: 'Invalid token', message: 'Session expired or invalid' }) } + // 🔒 安全修复:验证会话完整性 + if (!sessionData.username || !sessionData.loginTime) { + logger.security( + `🔒 Invalid session structure in /auth/change-password from ${req.ip || 'unknown'}` + ) + await redis.deleteSession(token) + return res.status(401).json({ + error: 'Invalid session', + message: 'Session data corrupted or incomplete' + }) + } + // 获取当前管理员信息 const adminData = await redis.getSession('admin_credentials') if (!adminData) { @@ -269,13 +283,25 @@ router.get('/auth/user', async (req, res) => { // 获取当前会话 const sessionData = await redis.getSession(token) - if (!sessionData) { + + // 🔒 安全修复:检查空对象 + if (!sessionData || Object.keys(sessionData).length === 0) { return res.status(401).json({ error: 'Invalid token', message: 'Session expired or invalid' }) } + // 🔒 安全修复:验证会话完整性 + if (!sessionData.username || !sessionData.loginTime) { + logger.security(`🔒 Invalid session structure in /auth/user from ${req.ip || 'unknown'}`) + await redis.deleteSession(token) + return res.status(401).json({ + error: 'Invalid session', + message: 'Session data corrupted or incomplete' + }) + } + // 获取管理员信息 const adminData = await redis.getSession('admin_credentials') if (!adminData) { @@ -316,13 +342,24 @@ router.post('/auth/refresh', async (req, res) => { const sessionData = await redis.getSession(token) - if (!sessionData) { + // 🔒 安全修复:检查空对象(hgetall 对不存在的 key 返回 {}) + if (!sessionData || Object.keys(sessionData).length === 0) { return res.status(401).json({ error: 'Invalid token', message: 'Session expired or invalid' }) } + // 🔒 安全修复:验证会话完整性(必须有 username 和 loginTime) + if (!sessionData.username || !sessionData.loginTime) { + logger.security(`🔒 Invalid session structure detected from ${req.ip || 'unknown'}`) + await redis.deleteSession(token) // 清理无效/伪造的会话 + return res.status(401).json({ + error: 'Invalid session', + message: 'Session data corrupted or incomplete' + }) + } + // 更新最后活动时间 sessionData.lastActivity = new Date().toISOString() await redis.setSession(token, sessionData, config.security.adminSessionTimeout) diff --git a/src/services/accountBalanceService.js b/src/services/accountBalanceService.js new file mode 100644 index 00000000..3265c4b8 --- /dev/null +++ b/src/services/accountBalanceService.js @@ -0,0 +1,748 @@ +const redis = require('../models/redis') +const balanceScriptService = require('./balanceScriptService') +const logger = require('../utils/logger') +const CostCalculator = require('../utils/costCalculator') +const { isBalanceScriptEnabled } = require('../utils/featureFlags') + +class AccountBalanceService { + constructor(options = {}) { + this.redis = options.redis || redis + this.logger = options.logger || logger + + this.providers = new Map() + + this.CACHE_TTL_SECONDS = 3600 + this.LOCAL_TTL_SECONDS = 300 + + this.LOW_BALANCE_THRESHOLD = 10 + this.HIGH_USAGE_THRESHOLD_PERCENT = 90 + this.DEFAULT_CONCURRENCY = 10 + } + + getSupportedPlatforms() { + return [ + 'claude', + 'claude-console', + 'gemini', + 'gemini-api', + 'openai', + 'openai-responses', + 'azure_openai', + 'bedrock', + 'droid', + 'ccr' + ] + } + + normalizePlatform(platform) { + if (!platform) { + return null + } + + const value = String(platform).trim().toLowerCase() + + // 兼容实施文档与历史命名 + if (value === 'claude-official') { + return 'claude' + } + if (value === 'azure-openai') { + return 'azure_openai' + } + + // 保持前端平台键一致 + return value + } + + registerProvider(platform, provider) { + const normalized = this.normalizePlatform(platform) + if (!normalized) { + throw new Error('registerProvider: 缺少 platform') + } + if (!provider || typeof provider.queryBalance !== 'function') { + throw new Error(`registerProvider: Provider 无效 (${normalized})`) + } + this.providers.set(normalized, provider) + } + + async getAccountBalance(accountId, platform, options = {}) { + const normalizedPlatform = this.normalizePlatform(platform) + const account = await this.getAccount(accountId, normalizedPlatform) + if (!account) { + return null + } + return await this._getAccountBalanceForAccount(account, normalizedPlatform, options) + } + + async refreshAccountBalance(accountId, platform) { + const normalizedPlatform = this.normalizePlatform(platform) + const account = await this.getAccount(accountId, normalizedPlatform) + if (!account) { + return null + } + + return await this._getAccountBalanceForAccount(account, normalizedPlatform, { + queryApi: true, + useCache: false + }) + } + + async getAllAccountsBalance(platform, options = {}) { + const normalizedPlatform = this.normalizePlatform(platform) + const accounts = await this.getAllAccountsByPlatform(normalizedPlatform) + const queryApi = this._parseBoolean(options.queryApi) || false + const useCache = options.useCache !== false + + const results = await this._mapWithConcurrency( + accounts, + this.DEFAULT_CONCURRENCY, + async (acc) => { + try { + const balance = await this._getAccountBalanceForAccount(acc, normalizedPlatform, { + queryApi, + useCache + }) + return { ...balance, name: acc.name || '' } + } catch (error) { + this.logger.error(`批量获取余额失败: ${normalizedPlatform}:${acc?.id}`, error) + return { + success: true, + data: { + accountId: acc?.id, + platform: normalizedPlatform, + balance: null, + quota: null, + statistics: {}, + source: 'local', + lastRefreshAt: new Date().toISOString(), + cacheExpiresAt: null, + status: 'error', + error: error.message || '批量查询失败' + }, + name: acc?.name || '' + } + } + } + ) + + return results + } + + async getBalanceSummary() { + const platforms = this.getSupportedPlatforms() + + const summary = { + totalBalance: 0, + totalCost: 0, + lowBalanceCount: 0, + platforms: {} + } + + for (const platform of platforms) { + const accounts = await this.getAllAccountsByPlatform(platform) + const platformData = { + count: accounts.length, + totalBalance: 0, + totalCost: 0, + lowBalanceCount: 0, + accounts: [] + } + + const balances = await this._mapWithConcurrency( + accounts, + this.DEFAULT_CONCURRENCY, + async (acc) => { + const balance = await this._getAccountBalanceForAccount(acc, platform, { + queryApi: false, + useCache: true + }) + return { ...balance, name: acc.name || '' } + } + ) + + for (const item of balances) { + platformData.accounts.push(item) + + const amount = item?.data?.balance?.amount + const percentage = item?.data?.quota?.percentage + const totalCost = Number(item?.data?.statistics?.totalCost || 0) + + const hasAmount = typeof amount === 'number' && Number.isFinite(amount) + const isLowBalance = hasAmount && amount < this.LOW_BALANCE_THRESHOLD + const isHighUsage = + typeof percentage === 'number' && + Number.isFinite(percentage) && + percentage > this.HIGH_USAGE_THRESHOLD_PERCENT + + if (hasAmount) { + platformData.totalBalance += amount + } + + if (isLowBalance || isHighUsage) { + platformData.lowBalanceCount += 1 + summary.lowBalanceCount += 1 + } + + platformData.totalCost += totalCost + } + + summary.platforms[platform] = platformData + summary.totalBalance += platformData.totalBalance + summary.totalCost += platformData.totalCost + } + + return summary + } + + async clearCache(accountId, platform) { + const normalizedPlatform = this.normalizePlatform(platform) + if (!normalizedPlatform) { + throw new Error('缺少 platform 参数') + } + + await this.redis.deleteAccountBalance(normalizedPlatform, accountId) + this.logger.info(`余额缓存已清除: ${normalizedPlatform}:${accountId}`) + } + + async getAccount(accountId, platform) { + if (!accountId || !platform) { + return null + } + + const serviceMap = { + claude: require('./claudeAccountService'), + 'claude-console': require('./claudeConsoleAccountService'), + gemini: require('./geminiAccountService'), + 'gemini-api': require('./geminiApiAccountService'), + openai: require('./openaiAccountService'), + 'openai-responses': require('./openaiResponsesAccountService'), + azure_openai: require('./azureOpenaiAccountService'), + bedrock: require('./bedrockAccountService'), + droid: require('./droidAccountService'), + ccr: require('./ccrAccountService') + } + + const service = serviceMap[platform] + if (!service || typeof service.getAccount !== 'function') { + return null + } + + return await service.getAccount(accountId) + } + + async getAllAccountsByPlatform(platform) { + if (!platform) { + return [] + } + + const serviceMap = { + claude: require('./claudeAccountService'), + 'claude-console': require('./claudeConsoleAccountService'), + gemini: require('./geminiAccountService'), + 'gemini-api': require('./geminiApiAccountService'), + openai: require('./openaiAccountService'), + 'openai-responses': require('./openaiResponsesAccountService'), + azure_openai: require('./azureOpenaiAccountService'), + bedrock: require('./bedrockAccountService'), + droid: require('./droidAccountService'), + ccr: require('./ccrAccountService') + } + + const service = serviceMap[platform] + if (!service) { + return [] + } + + // Bedrock 特殊:返回 { success, data } + if (platform === 'bedrock' && typeof service.getAllAccounts === 'function') { + const result = await service.getAllAccounts() + return result?.success ? result.data || [] : [] + } + + if (platform === 'openai-responses') { + return await service.getAllAccounts(true) + } + + if (typeof service.getAllAccounts !== 'function') { + return [] + } + + return await service.getAllAccounts() + } + + async _getAccountBalanceForAccount(account, platform, options = {}) { + const queryApi = this._parseBoolean(options.queryApi) || false + const useCache = options.useCache !== false + + const accountId = account?.id + if (!accountId) { + throw new Error('账户缺少 id') + } + + // 余额脚本配置状态(用于前端控制“刷新余额”按钮) + let scriptConfig = null + let scriptConfigured = false + if (typeof this.redis?.getBalanceScriptConfig === 'function') { + scriptConfig = await this.redis.getBalanceScriptConfig(platform, accountId) + scriptConfigured = !!( + scriptConfig && + scriptConfig.scriptBody && + String(scriptConfig.scriptBody).trim().length > 0 + ) + } + const scriptEnabled = isBalanceScriptEnabled() + const scriptMeta = { scriptEnabled, scriptConfigured } + + const localBalance = await this._getBalanceFromLocal(accountId, platform) + const localStatistics = localBalance.statistics || {} + + const quotaFromLocal = this._buildQuotaFromLocal(account, localStatistics) + + // 非强制查询:优先读缓存 + if (!queryApi) { + if (useCache) { + const cached = await this.redis.getAccountBalance(platform, accountId) + if (cached && cached.status === 'success') { + return this._buildResponse( + { + status: cached.status, + errorMessage: cached.errorMessage, + balance: quotaFromLocal.balance ?? cached.balance, + currency: quotaFromLocal.currency || cached.currency || 'USD', + quota: quotaFromLocal.quota || cached.quota || null, + statistics: localStatistics, + lastRefreshAt: cached.lastRefreshAt + }, + accountId, + platform, + 'cache', + cached.ttlSeconds, + scriptMeta + ) + } + } + + return this._buildResponse( + { + status: 'success', + errorMessage: null, + balance: quotaFromLocal.balance, + currency: quotaFromLocal.currency || 'USD', + quota: quotaFromLocal.quota, + statistics: localStatistics, + lastRefreshAt: localBalance.lastCalculated + }, + accountId, + platform, + 'local', + null, + scriptMeta + ) + } + + // 强制查询:优先脚本(如启用且已配置),否则调用 Provider;失败自动降级到本地统计 + let providerResult + + if (scriptEnabled && scriptConfigured) { + providerResult = await this._getBalanceFromScript(scriptConfig, accountId, platform) + } else { + const provider = this.providers.get(platform) + if (!provider) { + return this._buildResponse( + { + status: 'error', + errorMessage: `不支持的平台: ${platform}`, + balance: quotaFromLocal.balance, + currency: quotaFromLocal.currency || 'USD', + quota: quotaFromLocal.quota, + statistics: localStatistics, + lastRefreshAt: new Date().toISOString() + }, + accountId, + platform, + 'local', + null, + scriptMeta + ) + } + providerResult = await this._getBalanceFromProvider(provider, account) + } + + const isRemoteSuccess = + providerResult.status === 'success' && ['api', 'script'].includes(providerResult.queryMethod) + + // 仅缓存“真实远程查询成功”的结果,避免把字段/本地降级结果当作 API 结果缓存 1h + if (isRemoteSuccess) { + await this.redis.setAccountBalance( + platform, + accountId, + providerResult, + this.CACHE_TTL_SECONDS + ) + } + + const source = isRemoteSuccess ? 'api' : 'local' + + return this._buildResponse( + { + status: providerResult.status, + errorMessage: providerResult.errorMessage, + balance: quotaFromLocal.balance ?? providerResult.balance, + currency: quotaFromLocal.currency || providerResult.currency || 'USD', + quota: quotaFromLocal.quota || providerResult.quota || null, + statistics: localStatistics, + lastRefreshAt: providerResult.lastRefreshAt + }, + accountId, + platform, + source, + null, + scriptMeta + ) + } + + async _getBalanceFromScript(scriptConfig, accountId, platform) { + try { + const result = await balanceScriptService.execute({ + scriptBody: scriptConfig.scriptBody, + timeoutSeconds: scriptConfig.timeoutSeconds || 10, + variables: { + baseUrl: scriptConfig.baseUrl || '', + apiKey: scriptConfig.apiKey || '', + token: scriptConfig.token || '', + accountId, + platform, + extra: scriptConfig.extra || '' + } + }) + + const mapped = result?.mapped || {} + return { + status: mapped.status || 'error', + balance: typeof mapped.balance === 'number' ? mapped.balance : null, + currency: mapped.currency || 'USD', + quota: mapped.quota || null, + queryMethod: 'api', + rawData: mapped.rawData || result?.response?.data || null, + lastRefreshAt: new Date().toISOString(), + errorMessage: mapped.errorMessage || '' + } + } catch (error) { + return { + status: 'error', + balance: null, + currency: 'USD', + quota: null, + queryMethod: 'api', + rawData: null, + lastRefreshAt: new Date().toISOString(), + errorMessage: error.message || '脚本执行失败' + } + } + } + + async _getBalanceFromProvider(provider, account) { + try { + const result = await provider.queryBalance(account) + return { + status: 'success', + balance: typeof result?.balance === 'number' ? result.balance : null, + currency: result?.currency || 'USD', + quota: result?.quota || null, + queryMethod: result?.queryMethod || 'api', + rawData: result?.rawData || null, + lastRefreshAt: new Date().toISOString(), + errorMessage: '' + } + } catch (error) { + return { + status: 'error', + balance: null, + currency: 'USD', + quota: null, + queryMethod: 'api', + rawData: null, + lastRefreshAt: new Date().toISOString(), + errorMessage: error.message || '查询失败' + } + } + } + + async _getBalanceFromLocal(accountId, platform) { + const cached = await this.redis.getLocalBalance(platform, accountId) + if (cached && cached.statistics) { + return cached + } + + const statistics = await this._computeLocalStatistics(accountId) + const localBalance = { + status: 'success', + balance: null, + currency: 'USD', + statistics, + queryMethod: 'local', + lastCalculated: new Date().toISOString() + } + + await this.redis.setLocalBalance(platform, accountId, localBalance, this.LOCAL_TTL_SECONDS) + return localBalance + } + + async _computeLocalStatistics(accountId) { + const safeNumber = (value) => { + const num = Number(value) + return Number.isFinite(num) ? num : 0 + } + + try { + const usageStats = await this.redis.getAccountUsageStats(accountId) + const dailyCost = safeNumber(usageStats?.daily?.cost || 0) + const monthlyCost = await this._computeMonthlyCost(accountId) + const totalCost = await this._computeTotalCost(accountId) + + return { + totalCost, + dailyCost, + monthlyCost, + totalRequests: safeNumber(usageStats?.total?.requests || 0), + dailyRequests: safeNumber(usageStats?.daily?.requests || 0), + monthlyRequests: safeNumber(usageStats?.monthly?.requests || 0) + } + } catch (error) { + this.logger.debug(`本地统计计算失败: ${accountId}`, error) + return { + totalCost: 0, + dailyCost: 0, + monthlyCost: 0, + totalRequests: 0, + dailyRequests: 0, + monthlyRequests: 0 + } + } + } + + async _computeMonthlyCost(accountId) { + const tzDate = this.redis.getDateInTimezone(new Date()) + const currentMonth = `${tzDate.getUTCFullYear()}-${String(tzDate.getUTCMonth() + 1).padStart( + 2, + '0' + )}` + + const pattern = `account_usage:model:monthly:${accountId}:*:${currentMonth}` + return await this._sumModelCostsByKeysPattern(pattern) + } + + async _computeTotalCost(accountId) { + const pattern = `account_usage:model:monthly:${accountId}:*:*` + return await this._sumModelCostsByKeysPattern(pattern) + } + + async _sumModelCostsByKeysPattern(pattern) { + try { + const client = this.redis.getClientSafe() + let totalCost = 0 + let cursor = '0' + const scanCount = 200 + let iterations = 0 + const maxIterations = 2000 + + do { + const [nextCursor, keys] = await client.scan(cursor, 'MATCH', pattern, 'COUNT', scanCount) + cursor = nextCursor + iterations += 1 + + if (!keys || keys.length === 0) { + continue + } + + const pipeline = client.pipeline() + keys.forEach((key) => pipeline.hgetall(key)) + const results = await pipeline.exec() + + for (let i = 0; i < results.length; i += 1) { + const [, data] = results[i] || [] + if (!data || Object.keys(data).length === 0) { + continue + } + + const parts = String(keys[i]).split(':') + const model = parts[4] || 'unknown' + + const usage = { + input_tokens: parseInt(data.inputTokens || 0), + output_tokens: parseInt(data.outputTokens || 0), + cache_creation_input_tokens: parseInt(data.cacheCreateTokens || 0), + cache_read_input_tokens: parseInt(data.cacheReadTokens || 0) + } + + const costResult = CostCalculator.calculateCost(usage, model) + totalCost += costResult.costs.total || 0 + } + + if (iterations >= maxIterations) { + this.logger.warn(`SCAN 次数超过上限,停止汇总:${pattern}`) + break + } + } while (cursor !== '0') + + return totalCost + } catch (error) { + this.logger.debug(`汇总模型费用失败: ${pattern}`, error) + return 0 + } + } + + _buildQuotaFromLocal(account, statistics) { + if (!account || !Object.prototype.hasOwnProperty.call(account, 'dailyQuota')) { + return { balance: null, currency: null, quota: null } + } + + const dailyQuota = Number(account.dailyQuota || 0) + const used = Number(statistics?.dailyCost || 0) + + const resetAt = this._computeNextResetAt(account.quotaResetTime || '00:00') + + // 不限制 + if (!Number.isFinite(dailyQuota) || dailyQuota <= 0) { + return { + balance: null, + currency: 'USD', + quota: { + daily: Infinity, + used, + remaining: Infinity, + percentage: 0, + unlimited: true, + resetAt + } + } + } + + const remaining = Math.max(0, dailyQuota - used) + const percentage = dailyQuota > 0 ? (used / dailyQuota) * 100 : 0 + + return { + balance: remaining, + currency: 'USD', + quota: { + daily: dailyQuota, + used, + remaining, + resetAt, + percentage: Math.round(percentage * 100) / 100 + } + } + } + + _computeNextResetAt(resetTime) { + const now = new Date() + const tzNow = this.redis.getDateInTimezone(now) + const offsetMs = tzNow.getTime() - now.getTime() + + const [h, m] = String(resetTime || '00:00') + .split(':') + .map((n) => parseInt(n, 10)) + + const resetHour = Number.isFinite(h) ? h : 0 + const resetMinute = Number.isFinite(m) ? m : 0 + + const year = tzNow.getUTCFullYear() + const month = tzNow.getUTCMonth() + const day = tzNow.getUTCDate() + + let resetAtMs = Date.UTC(year, month, day, resetHour, resetMinute, 0, 0) - offsetMs + if (resetAtMs <= now.getTime()) { + resetAtMs += 24 * 60 * 60 * 1000 + } + + return new Date(resetAtMs).toISOString() + } + + _buildResponse(balanceData, accountId, platform, source, ttlSeconds = null, extraData = {}) { + const now = new Date() + + const amount = typeof balanceData.balance === 'number' ? balanceData.balance : null + const currency = balanceData.currency || 'USD' + + let cacheExpiresAt = null + if (source === 'cache') { + const ttl = + typeof ttlSeconds === 'number' && ttlSeconds > 0 ? ttlSeconds : this.CACHE_TTL_SECONDS + cacheExpiresAt = new Date(Date.now() + ttl * 1000).toISOString() + } + + return { + success: true, + data: { + accountId, + platform, + balance: + typeof amount === 'number' + ? { + amount, + currency, + formattedAmount: this._formatCurrency(amount, currency) + } + : null, + quota: balanceData.quota || null, + statistics: balanceData.statistics || {}, + source, + lastRefreshAt: balanceData.lastRefreshAt || now.toISOString(), + cacheExpiresAt, + status: balanceData.status || 'success', + error: balanceData.errorMessage || null, + ...(extraData && typeof extraData === 'object' ? extraData : {}) + } + } + } + + _formatCurrency(amount, currency = 'USD') { + try { + if (typeof amount !== 'number' || !Number.isFinite(amount)) { + return 'N/A' + } + return new Intl.NumberFormat('en-US', { style: 'currency', currency }).format(amount) + } catch (error) { + return `$${amount.toFixed(2)}` + } + } + + _parseBoolean(value) { + if (typeof value === 'boolean') { + return value + } + if (typeof value !== 'string') { + return null + } + const normalized = value.trim().toLowerCase() + if (normalized === 'true' || normalized === '1' || normalized === 'yes') { + return true + } + if (normalized === 'false' || normalized === '0' || normalized === 'no') { + return false + } + return null + } + + async _mapWithConcurrency(items, limit, mapper) { + const concurrency = Math.max(1, Number(limit) || 1) + const list = Array.isArray(items) ? items : [] + + const results = new Array(list.length) + let nextIndex = 0 + + const workers = new Array(Math.min(concurrency, list.length)).fill(null).map(async () => { + while (nextIndex < list.length) { + const currentIndex = nextIndex + nextIndex += 1 + results[currentIndex] = await mapper(list[currentIndex], currentIndex) + } + }) + + await Promise.all(workers) + return results + } +} + +const accountBalanceService = new AccountBalanceService() +module.exports = accountBalanceService +module.exports.AccountBalanceService = AccountBalanceService diff --git a/src/services/accountTestSchedulerService.js b/src/services/accountTestSchedulerService.js new file mode 100644 index 00000000..59b4c6af --- /dev/null +++ b/src/services/accountTestSchedulerService.js @@ -0,0 +1,420 @@ +/** + * 账户定时测试调度服务 + * 使用 node-cron 支持 crontab 表达式,为每个账户创建独立的定时任务 + */ + +const cron = require('node-cron') +const redis = require('../models/redis') +const logger = require('../utils/logger') + +class AccountTestSchedulerService { + constructor() { + // 存储每个账户的 cron 任务: Map + this.scheduledTasks = new Map() + // 定期刷新配置的间隔 (毫秒) + this.refreshIntervalMs = 60 * 1000 + this.refreshInterval = null + // 当前正在测试的账户 + this.testingAccounts = new Set() + // 是否已启动 + this.isStarted = false + } + + /** + * 验证 cron 表达式是否有效 + * @param {string} cronExpression - cron 表达式 + * @returns {boolean} + */ + validateCronExpression(cronExpression) { + // 长度检查(防止 DoS) + if (!cronExpression || cronExpression.length > 100) { + return false + } + return cron.validate(cronExpression) + } + + /** + * 启动调度器 + */ + async start() { + if (this.isStarted) { + logger.warn('⚠️ Account test scheduler is already running') + return + } + + this.isStarted = true + logger.info('🚀 Starting account test scheduler service (node-cron mode)') + + // 初始化所有已配置账户的定时任务 + await this._refreshAllTasks() + + // 定期刷新配置,以便动态添加/修改的配置能生效 + this.refreshInterval = setInterval(() => { + this._refreshAllTasks() + }, this.refreshIntervalMs) + + logger.info( + `📅 Account test scheduler started (refreshing configs every ${this.refreshIntervalMs / 1000}s)` + ) + } + + /** + * 停止调度器 + */ + stop() { + if (this.refreshInterval) { + clearInterval(this.refreshInterval) + this.refreshInterval = null + } + + // 停止所有 cron 任务 + for (const [accountKey, taskInfo] of this.scheduledTasks.entries()) { + taskInfo.task.stop() + logger.debug(`🛑 Stopped cron task for ${accountKey}`) + } + this.scheduledTasks.clear() + + this.isStarted = false + logger.info('🛑 Account test scheduler stopped') + } + + /** + * 刷新所有账户的定时任务 + * @private + */ + async _refreshAllTasks() { + try { + const platforms = ['claude', 'gemini', 'openai'] + const activeAccountKeys = new Set() + + // 并行加载所有平台的配置 + const allEnabledAccounts = await Promise.all( + platforms.map((platform) => + redis + .getEnabledTestAccounts(platform) + .then((accounts) => accounts.map((acc) => ({ ...acc, platform }))) + .catch((error) => { + logger.warn(`⚠️ Failed to load test accounts for platform ${platform}:`, error) + return [] + }) + ) + ) + + // 展平平台数据 + const flatAccounts = allEnabledAccounts.flat() + + for (const { accountId, cronExpression, model, platform } of flatAccounts) { + if (!cronExpression) { + logger.warn( + `⚠️ Account ${accountId} (${platform}) has no valid cron expression, skipping` + ) + continue + } + + const accountKey = `${platform}:${accountId}` + activeAccountKeys.add(accountKey) + + // 检查是否需要更新任务 + const existingTask = this.scheduledTasks.get(accountKey) + if (existingTask) { + // 如果 cron 表达式和模型都没变,不需要更新 + if (existingTask.cronExpression === cronExpression && existingTask.model === model) { + continue + } + // 配置变了,停止旧任务 + existingTask.task.stop() + logger.info(`🔄 Updating cron task for ${accountKey}: ${cronExpression}, model: ${model}`) + } else { + logger.info(`➕ Creating cron task for ${accountKey}: ${cronExpression}, model: ${model}`) + } + + // 创建新的 cron 任务 + this._createCronTask(accountId, platform, cronExpression, model) + } + + // 清理已删除或禁用的账户任务 + for (const [accountKey, taskInfo] of this.scheduledTasks.entries()) { + if (!activeAccountKeys.has(accountKey)) { + taskInfo.task.stop() + this.scheduledTasks.delete(accountKey) + logger.info(`➖ Removed cron task for ${accountKey} (disabled or deleted)`) + } + } + } catch (error) { + logger.error('❌ Error refreshing account test tasks:', error) + } + } + + /** + * 为单个账户创建 cron 任务 + * @param {string} accountId + * @param {string} platform + * @param {string} cronExpression + * @param {string} model - 测试使用的模型 + * @private + */ + _createCronTask(accountId, platform, cronExpression, model) { + const accountKey = `${platform}:${accountId}` + + // 验证 cron 表达式 + if (!this.validateCronExpression(cronExpression)) { + logger.error(`❌ Invalid cron expression for ${accountKey}: ${cronExpression}`) + return + } + + const task = cron.schedule( + cronExpression, + async () => { + await this._runAccountTest(accountId, platform, model) + }, + { + scheduled: true, + timezone: process.env.TZ || 'Asia/Shanghai' + } + ) + + this.scheduledTasks.set(accountKey, { + task, + cronExpression, + model, + accountId, + platform + }) + } + + /** + * 执行单个账户测试 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {string} model - 测试使用的模型 + * @private + */ + async _runAccountTest(accountId, platform, model) { + const accountKey = `${platform}:${accountId}` + + // 避免重复测试 + if (this.testingAccounts.has(accountKey)) { + logger.debug(`⏳ Account ${accountKey} is already being tested, skipping`) + return + } + + this.testingAccounts.add(accountKey) + + try { + logger.info( + `🧪 Running scheduled test for ${platform} account: ${accountId} (model: ${model})` + ) + + let testResult + + // 根据平台调用对应的测试方法 + switch (platform) { + case 'claude': + testResult = await this._testClaudeAccount(accountId, model) + break + case 'gemini': + testResult = await this._testGeminiAccount(accountId, model) + break + case 'openai': + testResult = await this._testOpenAIAccount(accountId, model) + break + default: + testResult = { + success: false, + error: `Unsupported platform: ${platform}`, + timestamp: new Date().toISOString() + } + } + + // 保存测试结果 + await redis.saveAccountTestResult(accountId, platform, testResult) + + // 更新最后测试时间 + await redis.setAccountLastTestTime(accountId, platform) + + // 记录日志 + if (testResult.success) { + logger.info( + `✅ Scheduled test passed for ${platform} account ${accountId} (${testResult.latencyMs}ms)` + ) + } else { + logger.warn( + `❌ Scheduled test failed for ${platform} account ${accountId}: ${testResult.error}` + ) + } + + return testResult + } catch (error) { + logger.error(`❌ Error testing ${platform} account ${accountId}:`, error) + + const errorResult = { + success: false, + error: error.message, + timestamp: new Date().toISOString() + } + + await redis.saveAccountTestResult(accountId, platform, errorResult) + await redis.setAccountLastTestTime(accountId, platform) + + return errorResult + } finally { + this.testingAccounts.delete(accountKey) + } + } + + /** + * 测试 Claude 账户 + * @param {string} accountId + * @param {string} model - 测试使用的模型 + * @private + */ + async _testClaudeAccount(accountId, model) { + const claudeRelayService = require('./claudeRelayService') + return await claudeRelayService.testAccountConnectionSync(accountId, model) + } + + /** + * 测试 Gemini 账户 + * @param {string} _accountId + * @param {string} _model + * @private + */ + async _testGeminiAccount(_accountId, _model) { + // Gemini 测试暂时返回未实现 + return { + success: false, + error: 'Gemini scheduled test not implemented yet', + timestamp: new Date().toISOString() + } + } + + /** + * 测试 OpenAI 账户 + * @param {string} _accountId + * @param {string} _model + * @private + */ + async _testOpenAIAccount(_accountId, _model) { + // OpenAI 测试暂时返回未实现 + return { + success: false, + error: 'OpenAI scheduled test not implemented yet', + timestamp: new Date().toISOString() + } + } + + /** + * 手动触发账户测试 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {string} model - 测试使用的模型 + * @returns {Promise} 测试结果 + */ + async triggerTest(accountId, platform, model = 'claude-sonnet-4-5-20250929') { + logger.info(`🎯 Manual test triggered for ${platform} account: ${accountId} (model: ${model})`) + return await this._runAccountTest(accountId, platform, model) + } + + /** + * 获取账户测试历史 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 测试历史 + */ + async getTestHistory(accountId, platform) { + return await redis.getAccountTestHistory(accountId, platform) + } + + /** + * 获取账户测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} + */ + async getTestConfig(accountId, platform) { + return await redis.getAccountTestConfig(accountId, platform) + } + + /** + * 设置账户测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {Object} testConfig - 测试配置 { enabled: boolean, cronExpression: string, model: string } + * @returns {Promise} + */ + async setTestConfig(accountId, platform, testConfig) { + // 验证 cron 表达式 + if (testConfig.cronExpression && !this.validateCronExpression(testConfig.cronExpression)) { + throw new Error(`Invalid cron expression: ${testConfig.cronExpression}`) + } + + await redis.saveAccountTestConfig(accountId, platform, testConfig) + logger.info( + `📝 Test config updated for ${platform} account ${accountId}: enabled=${testConfig.enabled}, cronExpression=${testConfig.cronExpression}, model=${testConfig.model}` + ) + + // 立即刷新任务,使配置立即生效 + if (this.isStarted) { + await this._refreshAllTasks() + } + } + + /** + * 更新单个账户的定时任务(配置变更时调用) + * @param {string} accountId + * @param {string} platform + */ + async refreshAccountTask(accountId, platform) { + if (!this.isStarted) { + return + } + + const accountKey = `${platform}:${accountId}` + const testConfig = await redis.getAccountTestConfig(accountId, platform) + + // 停止现有任务 + const existingTask = this.scheduledTasks.get(accountKey) + if (existingTask) { + existingTask.task.stop() + this.scheduledTasks.delete(accountKey) + } + + // 如果启用且有有效的 cron 表达式,创建新任务 + if (testConfig?.enabled && testConfig?.cronExpression) { + this._createCronTask(accountId, platform, testConfig.cronExpression, testConfig.model) + logger.info( + `🔄 Refreshed cron task for ${accountKey}: ${testConfig.cronExpression}, model: ${testConfig.model}` + ) + } + } + + /** + * 获取调度器状态 + * @returns {Object} + */ + getStatus() { + const tasks = [] + for (const [accountKey, taskInfo] of this.scheduledTasks.entries()) { + tasks.push({ + accountKey, + accountId: taskInfo.accountId, + platform: taskInfo.platform, + cronExpression: taskInfo.cronExpression, + model: taskInfo.model + }) + } + + return { + running: this.isStarted, + refreshIntervalMs: this.refreshIntervalMs, + scheduledTasksCount: this.scheduledTasks.size, + scheduledTasks: tasks, + currentlyTesting: Array.from(this.testingAccounts) + } + } +} + +// 单例模式 +const accountTestSchedulerService = new AccountTestSchedulerService() + +module.exports = accountTestSchedulerService diff --git a/src/services/apiKeyService.js b/src/services/apiKeyService.js index 0e9e7597..771f973b 100644 --- a/src/services/apiKeyService.js +++ b/src/services/apiKeyService.js @@ -37,6 +37,51 @@ const ACCOUNT_CATEGORY_MAP = { droid: 'droid' } +/** + * 规范化权限数据,兼容旧格式(字符串)和新格式(数组) + * @param {string|array} permissions - 权限数据 + * @returns {array} - 权限数组,空数组表示全部服务 + */ +function normalizePermissions(permissions) { + if (!permissions) { + return [] // 空 = 全部服务 + } + if (Array.isArray(permissions)) { + return permissions + } + // 尝试解析 JSON 字符串(新格式存储) + if (typeof permissions === 'string') { + if (permissions.startsWith('[')) { + try { + const parsed = JSON.parse(permissions) + if (Array.isArray(parsed)) { + return parsed + } + } catch (e) { + // 解析失败,继续处理为普通字符串 + } + } + // 旧格式 'all' 转为空数组 + if (permissions === 'all') { + return [] + } + // 旧单个字符串转为数组 + return [permissions] + } + return [] +} + +/** + * 检查是否有访问特定服务的权限 + * @param {string|array} permissions - 权限数据 + * @param {string} service - 服务名称(claude/gemini/openai/droid) + * @returns {boolean} - 是否有权限 + */ +function hasPermission(permissions, service) { + const perms = normalizePermissions(permissions) + return perms.length === 0 || perms.includes(service) // 空数组 = 全部服务 +} + function normalizeAccountTypeKey(type) { if (!type) { return null @@ -89,7 +134,7 @@ class ApiKeyService { azureOpenaiAccountId = null, bedrockAccountId = null, // 添加 Bedrock 账号ID支持 droidAccountId = null, - permissions = 'all', // 可选值:'claude'、'gemini'、'openai'、'droid' 或 'all' + permissions = [], // 数组格式,空数组表示全部服务,如 ['claude', 'gemini'] isActive = true, concurrencyLimit = 0, rateLimitWindow = null, @@ -132,7 +177,7 @@ class ApiKeyService { azureOpenaiAccountId: azureOpenaiAccountId || '', bedrockAccountId: bedrockAccountId || '', // 添加 Bedrock 账号ID droidAccountId: droidAccountId || '', - permissions: permissions || 'all', + permissions: JSON.stringify(normalizePermissions(permissions)), enableModelRestriction: String(enableModelRestriction), restrictedModels: JSON.stringify(restrictedModels || []), enableClientRestriction: String(enableClientRestriction || false), @@ -186,7 +231,7 @@ class ApiKeyService { azureOpenaiAccountId: keyData.azureOpenaiAccountId, bedrockAccountId: keyData.bedrockAccountId, // 添加 Bedrock 账号ID droidAccountId: keyData.droidAccountId, - permissions: keyData.permissions, + permissions: normalizePermissions(keyData.permissions), enableModelRestriction: keyData.enableModelRestriction === 'true', restrictedModels: JSON.parse(keyData.restrictedModels), enableClientRestriction: keyData.enableClientRestriction === 'true', @@ -338,7 +383,7 @@ class ApiKeyService { azureOpenaiAccountId: keyData.azureOpenaiAccountId, bedrockAccountId: keyData.bedrockAccountId, // 添加 Bedrock 账号ID droidAccountId: keyData.droidAccountId, - permissions: keyData.permissions || 'all', + permissions: normalizePermissions(keyData.permissions), tokenLimit: parseInt(keyData.tokenLimit), concurrencyLimit: parseInt(keyData.concurrencyLimit || 0), rateLimitWindow: parseInt(keyData.rateLimitWindow || 0), @@ -467,7 +512,7 @@ class ApiKeyService { azureOpenaiAccountId: keyData.azureOpenaiAccountId, bedrockAccountId: keyData.bedrockAccountId, droidAccountId: keyData.droidAccountId, - permissions: keyData.permissions || 'all', + permissions: normalizePermissions(keyData.permissions), tokenLimit: parseInt(keyData.tokenLimit), concurrencyLimit: parseInt(keyData.concurrencyLimit || 0), rateLimitWindow: parseInt(keyData.rateLimitWindow || 0), @@ -525,7 +570,7 @@ class ApiKeyService { key.isActive = key.isActive === 'true' key.enableModelRestriction = key.enableModelRestriction === 'true' key.enableClientRestriction = key.enableClientRestriction === 'true' - key.permissions = key.permissions || 'all' // 兼容旧数据 + key.permissions = normalizePermissions(key.permissions) key.dailyCostLimit = parseFloat(key.dailyCostLimit || 0) key.totalCostLimit = parseFloat(key.totalCostLimit || 0) key.weeklyOpusCostLimit = parseFloat(key.weeklyOpusCostLimit || 0) @@ -1568,7 +1613,7 @@ class ApiKeyService { userId: keyData.userId, userUsername: keyData.userUsername, createdBy: keyData.createdBy, - permissions: keyData.permissions, + permissions: normalizePermissions(keyData.permissions), dailyCostLimit: parseFloat(keyData.dailyCostLimit || 0), totalCostLimit: parseFloat(keyData.totalCostLimit || 0), // 所有平台账户绑定字段 @@ -1820,4 +1865,8 @@ const apiKeyService = new ApiKeyService() // 为了方便其他服务调用,导出 recordUsage 方法 apiKeyService.recordUsageMetrics = apiKeyService.recordUsage.bind(apiKeyService) +// 导出权限辅助函数供路由使用 +apiKeyService.hasPermission = hasPermission +apiKeyService.normalizePermissions = normalizePermissions + module.exports = apiKeyService diff --git a/src/services/balanceProviders/baseBalanceProvider.js b/src/services/balanceProviders/baseBalanceProvider.js new file mode 100644 index 00000000..ececd2e5 --- /dev/null +++ b/src/services/balanceProviders/baseBalanceProvider.js @@ -0,0 +1,133 @@ +const axios = require('axios') +const logger = require('../../utils/logger') +const ProxyHelper = require('../../utils/proxyHelper') + +/** + * Provider 抽象基类 + * 各平台 Provider 需继承并实现 queryBalance(account) + */ +class BaseBalanceProvider { + constructor(platform) { + this.platform = platform + this.logger = logger + } + + /** + * 查询余额(抽象方法) + * @param {object} account - 账户对象 + * @returns {Promise} + * 形如: + * { + * balance: number|null, + * currency?: string, + * quota?: { daily, used, remaining, resetAt, percentage, unlimited? }, + * queryMethod?: 'api'|'field'|'local', + * rawData?: any + * } + */ + async queryBalance(_account) { + throw new Error('queryBalance 方法必须由子类实现') + } + + /** + * 通用 HTTP 请求方法(支持代理) + * @param {string} url + * @param {object} options + * @param {object} account + */ + async makeRequest(url, options = {}, account = {}) { + const config = { + url, + method: options.method || 'GET', + headers: options.headers || {}, + timeout: options.timeout || 15000, + data: options.data, + params: options.params, + responseType: options.responseType + } + + const proxyConfig = account.proxyConfig || account.proxy + if (proxyConfig) { + const agent = ProxyHelper.createProxyAgent(proxyConfig) + if (agent) { + config.httpAgent = agent + config.httpsAgent = agent + config.proxy = false + } + } + + try { + const response = await axios(config) + return { + success: true, + data: response.data, + status: response.status, + headers: response.headers + } + } catch (error) { + const status = error.response?.status + const message = error.response?.data?.message || error.message || '请求失败' + this.logger.debug(`余额 Provider HTTP 请求失败: ${url} (${this.platform})`, { + status, + message + }) + return { success: false, status, error: message } + } + } + + /** + * 从账户字段读取 dailyQuota / dailyUsage(通用降级方案) + * 注意:部分平台 dailyUsage 字段可能不是实时值,最终以 AccountBalanceService 的本地统计为准 + */ + readQuotaFromFields(account) { + const dailyQuota = Number(account?.dailyQuota || 0) + const dailyUsage = Number(account?.dailyUsage || 0) + + // 无限制 + if (!Number.isFinite(dailyQuota) || dailyQuota <= 0) { + return { + balance: null, + currency: 'USD', + quota: { + daily: Infinity, + used: Number.isFinite(dailyUsage) ? dailyUsage : 0, + remaining: Infinity, + percentage: 0, + unlimited: true + }, + queryMethod: 'field' + } + } + + const used = Number.isFinite(dailyUsage) ? dailyUsage : 0 + const remaining = Math.max(0, dailyQuota - used) + const percentage = dailyQuota > 0 ? (used / dailyQuota) * 100 : 0 + + return { + balance: remaining, + currency: 'USD', + quota: { + daily: dailyQuota, + used, + remaining, + percentage: Math.round(percentage * 100) / 100 + }, + queryMethod: 'field' + } + } + + parseCurrency(data) { + return data?.currency || data?.Currency || 'USD' + } + + async safeExecute(fn, fallbackValue = null) { + try { + return await fn() + } catch (error) { + this.logger.error(`余额 Provider 执行失败: ${this.platform}`, error) + return fallbackValue + } + } +} + +module.exports = BaseBalanceProvider diff --git a/src/services/balanceProviders/claudeBalanceProvider.js b/src/services/balanceProviders/claudeBalanceProvider.js new file mode 100644 index 00000000..89783028 --- /dev/null +++ b/src/services/balanceProviders/claudeBalanceProvider.js @@ -0,0 +1,30 @@ +const BaseBalanceProvider = require('./baseBalanceProvider') +const claudeAccountService = require('../claudeAccountService') + +class ClaudeBalanceProvider extends BaseBalanceProvider { + constructor() { + super('claude') + } + + /** + * Claude(OAuth):优先尝试获取 OAuth usage(用于配额/使用信息),不强行提供余额金额 + */ + async queryBalance(account) { + this.logger.debug(`查询 Claude 余额(OAuth usage): ${account?.id}`) + + // 仅 OAuth 账户可用;失败时降级 + const usageData = await claudeAccountService.fetchOAuthUsage(account.id).catch(() => null) + if (!usageData) { + return { balance: null, currency: 'USD', queryMethod: 'local' } + } + + return { + balance: null, + currency: 'USD', + queryMethod: 'api', + rawData: usageData + } + } +} + +module.exports = ClaudeBalanceProvider diff --git a/src/services/balanceProviders/claudeConsoleBalanceProvider.js b/src/services/balanceProviders/claudeConsoleBalanceProvider.js new file mode 100644 index 00000000..f5441047 --- /dev/null +++ b/src/services/balanceProviders/claudeConsoleBalanceProvider.js @@ -0,0 +1,14 @@ +const BaseBalanceProvider = require('./baseBalanceProvider') + +class ClaudeConsoleBalanceProvider extends BaseBalanceProvider { + constructor() { + super('claude-console') + } + + async queryBalance(account) { + this.logger.debug(`查询 Claude Console 余额(字段): ${account?.id}`) + return this.readQuotaFromFields(account) + } +} + +module.exports = ClaudeConsoleBalanceProvider diff --git a/src/services/balanceProviders/genericBalanceProvider.js b/src/services/balanceProviders/genericBalanceProvider.js new file mode 100644 index 00000000..6b3efe2b --- /dev/null +++ b/src/services/balanceProviders/genericBalanceProvider.js @@ -0,0 +1,23 @@ +const BaseBalanceProvider = require('./baseBalanceProvider') + +class GenericBalanceProvider extends BaseBalanceProvider { + constructor(platform) { + super(platform) + } + + async queryBalance(account) { + this.logger.debug(`${this.platform} 暂无专用余额 API,实现降级策略`) + + if (account && Object.prototype.hasOwnProperty.call(account, 'dailyQuota')) { + return this.readQuotaFromFields(account) + } + + return { + balance: null, + currency: 'USD', + queryMethod: 'local' + } + } +} + +module.exports = GenericBalanceProvider diff --git a/src/services/balanceProviders/index.js b/src/services/balanceProviders/index.js new file mode 100644 index 00000000..d55fda5b --- /dev/null +++ b/src/services/balanceProviders/index.js @@ -0,0 +1,24 @@ +const ClaudeBalanceProvider = require('./claudeBalanceProvider') +const ClaudeConsoleBalanceProvider = require('./claudeConsoleBalanceProvider') +const OpenAIResponsesBalanceProvider = require('./openaiResponsesBalanceProvider') +const GenericBalanceProvider = require('./genericBalanceProvider') + +function registerAllProviders(balanceService) { + // Claude + balanceService.registerProvider('claude', new ClaudeBalanceProvider()) + balanceService.registerProvider('claude-console', new ClaudeConsoleBalanceProvider()) + + // OpenAI / Codex + balanceService.registerProvider('openai-responses', new OpenAIResponsesBalanceProvider()) + balanceService.registerProvider('openai', new GenericBalanceProvider('openai')) + balanceService.registerProvider('azure_openai', new GenericBalanceProvider('azure_openai')) + + // 其他平台(降级) + balanceService.registerProvider('gemini', new GenericBalanceProvider('gemini')) + balanceService.registerProvider('gemini-api', new GenericBalanceProvider('gemini-api')) + balanceService.registerProvider('bedrock', new GenericBalanceProvider('bedrock')) + balanceService.registerProvider('droid', new GenericBalanceProvider('droid')) + balanceService.registerProvider('ccr', new GenericBalanceProvider('ccr')) +} + +module.exports = { registerAllProviders } diff --git a/src/services/balanceProviders/openaiResponsesBalanceProvider.js b/src/services/balanceProviders/openaiResponsesBalanceProvider.js new file mode 100644 index 00000000..9ff8433e --- /dev/null +++ b/src/services/balanceProviders/openaiResponsesBalanceProvider.js @@ -0,0 +1,54 @@ +const BaseBalanceProvider = require('./baseBalanceProvider') + +class OpenAIResponsesBalanceProvider extends BaseBalanceProvider { + constructor() { + super('openai-responses') + } + + /** + * OpenAI-Responses: + * - 优先使用 dailyQuota 字段(如果配置了额度) + * - 可选:尝试调用兼容 API(不同服务商实现不一,失败自动降级) + */ + async queryBalance(account) { + this.logger.debug(`查询 OpenAI Responses 余额: ${account?.id}`) + + // 配置了额度时直接返回(字段法) + if (account?.dailyQuota && Number(account.dailyQuota) > 0) { + return this.readQuotaFromFields(account) + } + + // 尝试调用 usage 接口(兼容性不保证) + if (account?.apiKey && account?.baseApi) { + const baseApi = String(account.baseApi).replace(/\/$/, '') + const response = await this.makeRequest( + `${baseApi}/v1/usage`, + { + method: 'GET', + headers: { + Authorization: `Bearer ${account.apiKey}`, + 'Content-Type': 'application/json' + } + }, + account + ) + + if (response.success) { + return { + balance: null, + currency: this.parseCurrency(response.data), + queryMethod: 'api', + rawData: response.data + } + } + } + + return { + balance: null, + currency: 'USD', + queryMethod: 'local' + } + } +} + +module.exports = OpenAIResponsesBalanceProvider diff --git a/src/services/balanceScriptService.js b/src/services/balanceScriptService.js new file mode 100644 index 00000000..5bf06801 --- /dev/null +++ b/src/services/balanceScriptService.js @@ -0,0 +1,161 @@ +const vm = require('vm') +const axios = require('axios') +const { isBalanceScriptEnabled } = require('../utils/featureFlags') + +/** + * 可配置脚本余额查询执行器 + * - 脚本格式:({ request: {...}, extractor: function(response){...} }) + * - 模板变量:{{baseUrl}}, {{apiKey}}, {{token}}, {{accountId}}, {{platform}}, {{extra}} + */ +class BalanceScriptService { + /** + * 执行脚本:返回标准余额结构 + 原始响应 + * @param {object} options + * - scriptBody: string + * - variables: Record + * - timeoutSeconds: number + */ + async execute(options = {}) { + if (!isBalanceScriptEnabled()) { + const error = new Error('余额脚本功能已禁用(可通过 BALANCE_SCRIPT_ENABLED=true 启用)') + error.code = 'BALANCE_SCRIPT_DISABLED' + throw error + } + + const scriptBody = options.scriptBody?.trim() + if (!scriptBody) { + throw new Error('脚本内容为空') + } + + const timeoutMs = Math.max(1, (options.timeoutSeconds || 10) * 1000) + const sandbox = { + console, + Math, + Date + } + + let scriptResult + try { + const wrapped = scriptBody.startsWith('(') ? scriptBody : `(${scriptBody})` + const script = new vm.Script(wrapped) + scriptResult = script.runInNewContext(sandbox, { timeout: timeoutMs }) + } catch (error) { + throw new Error(`脚本解析失败: ${error.message}`) + } + + if (!scriptResult || typeof scriptResult !== 'object') { + throw new Error('脚本返回格式无效(需返回 { request, extractor })') + } + + const variables = options.variables || {} + const request = this.applyTemplates(scriptResult.request || {}, variables) + const { extractor } = scriptResult + + if (!request?.url || typeof request.url !== 'string') { + throw new Error('脚本 request.url 不能为空') + } + + if (typeof extractor !== 'function') { + throw new Error('脚本 extractor 必须是函数') + } + + const axiosConfig = { + url: request.url, + method: (request.method || 'GET').toUpperCase(), + headers: request.headers || {}, + timeout: timeoutMs + } + + if (request.params) { + axiosConfig.params = request.params + } + if (request.body || request.data) { + axiosConfig.data = request.body || request.data + } + + let httpResponse + try { + httpResponse = await axios(axiosConfig) + } catch (error) { + const { response } = error || {} + const { status, data } = response || {} + throw new Error( + `请求失败: ${status || ''} ${error.message}${data ? ` | ${JSON.stringify(data)}` : ''}` + ) + } + + const responseData = httpResponse?.data + + let extracted = {} + try { + extracted = extractor(responseData) || {} + } catch (error) { + throw new Error(`extractor 执行失败: ${error.message}`) + } + + const mapped = this.mapExtractorResult(extracted, responseData) + return { + mapped, + extracted, + response: { + status: httpResponse?.status, + headers: httpResponse?.headers, + data: responseData + } + } + } + + applyTemplates(value, variables) { + if (typeof value === 'string') { + return value.replace(/{{(\w+)}}/g, (_, key) => { + const trimmed = key.trim() + return variables[trimmed] !== undefined ? String(variables[trimmed]) : '' + }) + } + if (Array.isArray(value)) { + return value.map((item) => this.applyTemplates(item, variables)) + } + if (value && typeof value === 'object') { + const result = {} + Object.keys(value).forEach((k) => { + result[k] = this.applyTemplates(value[k], variables) + }) + return result + } + return value + } + + mapExtractorResult(result = {}, responseData) { + const isValid = result.isValid !== false + const remaining = Number(result.remaining) + const total = Number(result.total) + const used = Number(result.used) + const currency = result.unit || 'USD' + + const quota = + Number.isFinite(total) || Number.isFinite(used) + ? { + total: Number.isFinite(total) ? total : null, + used: Number.isFinite(used) ? used : null, + remaining: Number.isFinite(remaining) ? remaining : null, + percentage: + Number.isFinite(total) && total > 0 && Number.isFinite(used) + ? (used / total) * 100 + : null + } + : null + + return { + status: isValid ? 'success' : 'error', + errorMessage: isValid ? '' : result.invalidMessage || '套餐无效', + balance: Number.isFinite(remaining) ? remaining : null, + currency, + quota, + planName: result.planName || null, + extra: result.extra || null, + rawData: responseData || result.raw + } + } +} + +module.exports = new BalanceScriptService() diff --git a/src/services/claudeAccountService.js b/src/services/claudeAccountService.js index 77630364..a2f8e6d2 100644 --- a/src/services/claudeAccountService.js +++ b/src/services/claudeAccountService.js @@ -91,7 +91,9 @@ class ClaudeAccountService { useUnifiedClientId = false, // 是否使用统一的客户端标识 unifiedClientId = '', // 统一的客户端标识 expiresAt = null, // 账户订阅到期时间 - extInfo = null // 额外扩展信息 + extInfo = null, // 额外扩展信息 + maxConcurrency = 0, // 账户级用户消息串行队列:0=使用全局配置,>0=强制启用串行 + interceptWarmup = false // 拦截预热请求(标题生成、Warmup等) } = options const accountId = uuidv4() @@ -136,7 +138,11 @@ class ClaudeAccountService { // 账户订阅到期时间 subscriptionExpiresAt: expiresAt || '', // 扩展信息 - extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '' + extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '', + // 账户级用户消息串行队列限制 + maxConcurrency: maxConcurrency.toString(), + // 拦截预热请求 + interceptWarmup: interceptWarmup.toString() } } else { // 兼容旧格式 @@ -168,7 +174,11 @@ class ClaudeAccountService { // 账户订阅到期时间 subscriptionExpiresAt: expiresAt || '', // 扩展信息 - extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '' + extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '', + // 账户级用户消息串行队列限制 + maxConcurrency: maxConcurrency.toString(), + // 拦截预热请求 + interceptWarmup: interceptWarmup.toString() } } @@ -216,7 +226,8 @@ class ClaudeAccountService { useUnifiedUserAgent, useUnifiedClientId, unifiedClientId, - extInfo: normalizedExtInfo + extInfo: normalizedExtInfo, + interceptWarmup } } @@ -574,7 +585,11 @@ class ClaudeAccountService { // 添加停止原因 stoppedReason: account.stoppedReason || null, // 扩展信息 - extInfo: parsedExtInfo + extInfo: parsedExtInfo, + // 账户级用户消息串行队列限制 + maxConcurrency: parseInt(account.maxConcurrency || '0', 10), + // 拦截预热请求 + interceptWarmup: account.interceptWarmup === 'true' } }) ) @@ -666,7 +681,9 @@ class ClaudeAccountService { 'useUnifiedClientId', 'unifiedClientId', 'subscriptionExpiresAt', - 'extInfo' + 'extInfo', + 'maxConcurrency', + 'interceptWarmup' ] const updatedData = { ...accountData } let shouldClearAutoStopFields = false @@ -681,7 +698,7 @@ class ClaudeAccountService { updatedData[field] = this._encryptSensitiveData(value) } else if (field === 'proxy') { updatedData[field] = value ? JSON.stringify(value) : '' - } else if (field === 'priority') { + } else if (field === 'priority' || field === 'maxConcurrency') { updatedData[field] = value.toString() } else if (field === 'subscriptionInfo') { // 处理订阅信息更新 diff --git a/src/services/claudeConsoleAccountService.js b/src/services/claudeConsoleAccountService.js index 5ffc5d46..a46af870 100644 --- a/src/services/claudeConsoleAccountService.js +++ b/src/services/claudeConsoleAccountService.js @@ -68,7 +68,8 @@ class ClaudeConsoleAccountService { dailyQuota = 0, // 每日额度限制(美元),0表示不限制 quotaResetTime = '00:00', // 额度重置时间(HH:mm格式) maxConcurrentTasks = 0, // 最大并发任务数,0表示无限制 - disableAutoProtection = false // 是否关闭自动防护(429/401/400/529 不自动禁用) + disableAutoProtection = false, // 是否关闭自动防护(429/401/400/529 不自动禁用) + interceptWarmup = false // 拦截预热请求(标题生成、Warmup等) } = options // 验证必填字段 @@ -117,7 +118,8 @@ class ClaudeConsoleAccountService { quotaResetTime, // 额度重置时间 quotaStoppedAt: '', // 因额度停用的时间 maxConcurrentTasks: maxConcurrentTasks.toString(), // 最大并发任务数,0表示无限制 - disableAutoProtection: disableAutoProtection.toString() // 关闭自动防护 + disableAutoProtection: disableAutoProtection.toString(), // 关闭自动防护 + interceptWarmup: interceptWarmup.toString() // 拦截预热请求 } const client = redis.getClientSafe() @@ -156,6 +158,7 @@ class ClaudeConsoleAccountService { quotaStoppedAt: null, maxConcurrentTasks, // 新增:返回并发限制配置 disableAutoProtection, // 新增:返回自动防护开关 + interceptWarmup, // 新增:返回预热请求拦截开关 activeTaskCount: 0 // 新增:新建账户当前并发数为0 } } @@ -217,7 +220,9 @@ class ClaudeConsoleAccountService { // 并发控制相关 maxConcurrentTasks: parseInt(accountData.maxConcurrentTasks) || 0, activeTaskCount, - disableAutoProtection: accountData.disableAutoProtection === 'true' + disableAutoProtection: accountData.disableAutoProtection === 'true', + // 拦截预热请求 + interceptWarmup: accountData.interceptWarmup === 'true' }) } } @@ -375,6 +380,9 @@ class ClaudeConsoleAccountService { if (updates.disableAutoProtection !== undefined) { updatedData.disableAutoProtection = updates.disableAutoProtection.toString() } + if (updates.interceptWarmup !== undefined) { + updatedData.interceptWarmup = updates.interceptWarmup.toString() + } // ✅ 直接保存 subscriptionExpiresAt(如果提供) // Claude Console 没有 token 刷新逻辑,不会覆盖此字段 diff --git a/src/services/claudeRelayService.js b/src/services/claudeRelayService.js index 36671fee..8fb90685 100644 --- a/src/services/claudeRelayService.js +++ b/src/services/claudeRelayService.js @@ -210,7 +210,17 @@ class ClaudeRelayService { logger.error('❌ accountId missing for queue lock in relayRequest') throw new Error('accountId missing for queue lock') } - const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + // 获取账户信息以检查账户级串行队列配置 + const accountForQueue = await claudeAccountService.getAccount(accountId) + const accountConfig = accountForQueue + ? { maxConcurrency: parseInt(accountForQueue.maxConcurrency || '0', 10) } + : null + const queueResult = await userMessageQueueService.acquireQueueLock( + accountId, + null, + null, + accountConfig + ) if (!queueResult.acquired && !queueResult.skipped) { // 区分 Redis 后端错误和队列超时 const isBackendError = queueResult.error === 'queue_backend_error' @@ -323,17 +333,46 @@ class ClaudeRelayService { } // 发送请求到Claude API(传入回调以获取请求对象) - const response = await this._makeClaudeRequest( - processedBody, - accessToken, - proxyAgent, - clientHeaders, - accountId, - (req) => { - upstreamRequest = req - }, - options - ) + // 🔄 403 重试机制:仅对 claude-official 类型账户(OAuth 或 Setup Token) + const maxRetries = this._shouldRetryOn403(accountType) ? 2 : 0 + let retryCount = 0 + let response + let shouldRetry = false + + do { + response = await this._makeClaudeRequest( + processedBody, + accessToken, + proxyAgent, + clientHeaders, + accountId, + (req) => { + upstreamRequest = req + }, + options + ) + + // 检查是否需要重试 403 + shouldRetry = response.statusCode === 403 && retryCount < maxRetries + if (shouldRetry) { + retryCount++ + logger.warn( + `🔄 403 error for account ${accountId}, retry ${retryCount}/${maxRetries} after 2s` + ) + await this._sleep(2000) + } + } while (shouldRetry) + + // 如果进行了重试,记录最终结果 + if (retryCount > 0) { + if (response.statusCode === 403) { + logger.error(`🚫 403 error persists for account ${accountId} after ${retryCount} retries`) + } else { + logger.info( + `✅ 403 retry successful for account ${accountId} on attempt ${retryCount}, got status ${response.statusCode}` + ) + } + } // 📬 请求已发送成功,立即释放队列锁(无需等待响应处理完成) // 因为 Claude API 限流基于请求发送时刻计算(RPM),不是请求完成时刻 @@ -398,9 +437,10 @@ class ClaudeRelayService { } } // 检查是否为403状态码(禁止访问) + // 注意:如果进行了重试,retryCount > 0;这里的 403 是重试后最终的结果 else if (response.statusCode === 403) { logger.error( - `🚫 Forbidden error (403) detected for account ${accountId}, marking as blocked` + `🚫 Forbidden error (403) detected for account ${accountId}${retryCount > 0 ? ` after ${retryCount} retries` : ''}, marking as blocked` ) await unifiedClaudeScheduler.markAccountBlocked(accountId, accountType, sessionHash) } @@ -1314,7 +1354,17 @@ class ClaudeRelayService { logger.error('❌ accountId missing for queue lock in relayStreamRequestWithUsageCapture') throw new Error('accountId missing for queue lock') } - const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + // 获取账户信息以检查账户级串行队列配置 + const accountForQueue = await claudeAccountService.getAccount(accountId) + const accountConfig = accountForQueue + ? { maxConcurrency: parseInt(accountForQueue.maxConcurrency || '0', 10) } + : null + const queueResult = await userMessageQueueService.acquireQueueLock( + accountId, + null, + null, + accountConfig + ) if (!queueResult.acquired && !queueResult.skipped) { // 区分 Redis 后端错误和队列超时 const isBackendError = queueResult.error === 'queue_backend_error' @@ -1497,8 +1547,10 @@ class ClaudeRelayService { streamTransformer = null, requestOptions = {}, isDedicatedOfficialAccount = false, - onResponseStart = null // 📬 新增:收到响应头时的回调,用于提前释放队列锁 + onResponseStart = null, // 📬 新增:收到响应头时的回调,用于提前释放队列锁 + retryCount = 0 // 🔄 403 重试计数器 ) { + const maxRetries = 2 // 最大重试次数 // 获取账户信息用于统一 User-Agent const account = await claudeAccountService.getAccount(accountId) @@ -1611,6 +1663,51 @@ class ClaudeRelayService { } } + // 🔄 403 重试机制(必须在设置 res.on('data')/res.on('end') 之前处理) + // 否则重试时旧响应的 on('end') 会与新请求产生竞态条件 + if (res.statusCode === 403) { + const canRetry = + this._shouldRetryOn403(accountType) && + retryCount < maxRetries && + !responseStream.headersSent + + if (canRetry) { + logger.warn( + `🔄 [Stream] 403 error for account ${accountId}, retry ${retryCount + 1}/${maxRetries} after 2s` + ) + // 消费当前响应并销毁请求 + res.resume() + req.destroy() + + // 等待 2 秒后递归重试 + await this._sleep(2000) + + try { + // 递归调用自身进行重试 + const retryResult = await this._makeClaudeStreamRequestWithUsageCapture( + body, + accessToken, + proxyAgent, + clientHeaders, + responseStream, + usageCallback, + accountId, + accountType, + sessionHash, + streamTransformer, + requestOptions, + isDedicatedOfficialAccount, + onResponseStart, + retryCount + 1 + ) + resolve(retryResult) + } catch (retryError) { + reject(retryError) + } + return // 重要:提前返回,不设置后续的错误处理器 + } + } + // 将错误处理逻辑封装在一个异步函数中 const handleErrorResponse = async () => { if (res.statusCode === 401) { @@ -1634,8 +1731,10 @@ class ClaudeRelayService { ) } } else if (res.statusCode === 403) { + // 403 处理:走到这里说明重试已用尽或不适用重试,直接标记 blocked + // 注意:重试逻辑已在 handleErrorResponse 外部提前处理 logger.error( - `🚫 [Stream] Forbidden error (403) detected for account ${accountId}, marking as blocked` + `🚫 [Stream] Forbidden error (403) detected for account ${accountId}${retryCount > 0 ? ` after ${retryCount} retries` : ''}, marking as blocked` ) await unifiedClaudeScheduler.markAccountBlocked(accountId, accountType, sessionHash) } else if (res.statusCode === 529) { @@ -2456,28 +2555,35 @@ class ClaudeRelayService { } } + // 🔧 准备测试请求的公共逻辑(供 testAccountConnection 和 testAccountConnectionSync 共用) + async _prepareAccountForTest(accountId) { + // 获取账户信息 + const account = await claudeAccountService.getAccount(accountId) + if (!account) { + throw new Error('Account not found') + } + + // 获取有效的访问token + const accessToken = await claudeAccountService.getValidAccessToken(accountId) + if (!accessToken) { + throw new Error('Failed to get valid access token') + } + + // 获取代理配置 + const proxyAgent = await this._getProxyAgent(accountId) + + return { account, accessToken, proxyAgent } + } + // 🧪 测试账号连接(供Admin API使用,直接复用 _makeClaudeStreamRequestWithUsageCapture) - async testAccountConnection(accountId, responseStream) { - const testRequestBody = createClaudeTestPayload('claude-sonnet-4-5-20250929', { stream: true }) + async testAccountConnection(accountId, responseStream, model = 'claude-sonnet-4-5-20250929') { + const testRequestBody = createClaudeTestPayload(model, { stream: true }) try { - // 获取账户信息 - const account = await claudeAccountService.getAccount(accountId) - if (!account) { - throw new Error('Account not found') - } + const { account, accessToken, proxyAgent } = await this._prepareAccountForTest(accountId) logger.info(`🧪 Testing Claude account connection: ${account.name} (${accountId})`) - // 获取有效的访问token - const accessToken = await claudeAccountService.getValidAccessToken(accountId) - if (!accessToken) { - throw new Error('Failed to get valid access token') - } - - // 获取代理配置 - const proxyAgent = await this._getProxyAgent(accountId) - // 设置响应头 if (!responseStream.headersSent) { const existingConnection = responseStream.getHeader @@ -2526,6 +2632,125 @@ class ClaudeRelayService { } } + // 🧪 非流式测试账号连接(供定时任务使用) + // 复用流式请求方法,收集结果后返回 + async testAccountConnectionSync(accountId, model = 'claude-sonnet-4-5-20250929') { + const testRequestBody = createClaudeTestPayload(model, { stream: true }) + const startTime = Date.now() + + try { + // 使用公共方法准备测试所需的账户信息、token 和代理 + const { account, accessToken, proxyAgent } = await this._prepareAccountForTest(accountId) + + logger.info(`🧪 Testing Claude account connection (sync): ${account.name} (${accountId})`) + + // 创建一个收集器来捕获流式响应 + let responseText = '' + let capturedUsage = null + let capturedModel = model + let hasError = false + let errorMessage = '' + + // 创建模拟的响应流对象 + const mockResponseStream = { + headersSent: true, // 跳过设置响应头 + write: (data) => { + // 解析 SSE 数据 + if (typeof data === 'string' && data.startsWith('data: ')) { + try { + const jsonStr = data.replace('data: ', '').trim() + if (jsonStr && jsonStr !== '[DONE]') { + const parsed = JSON.parse(jsonStr) + // 提取文本内容 + if (parsed.type === 'content_block_delta' && parsed.delta?.text) { + responseText += parsed.delta.text + } + // 提取 usage 信息 + if (parsed.type === 'message_delta' && parsed.usage) { + capturedUsage = parsed.usage + } + // 提取模型信息 + if (parsed.type === 'message_start' && parsed.message?.model) { + capturedModel = parsed.message.model + } + // 检测错误 + if (parsed.type === 'error') { + hasError = true + errorMessage = parsed.error?.message || 'Unknown error' + } + } + } catch { + // 忽略解析错误 + } + } + return true + }, + end: () => {}, + on: () => {}, + once: () => {}, + emit: () => {}, + writable: true + } + + // 复用流式请求方法 + await this._makeClaudeStreamRequestWithUsageCapture( + testRequestBody, + accessToken, + proxyAgent, + {}, // clientHeaders - 测试不需要客户端headers + mockResponseStream, + null, // usageCallback - 测试不需要统计 + accountId, + 'claude-official', // accountType + null, // sessionHash - 测试不需要会话 + null, // streamTransformer - 不需要转换,直接解析原始格式 + {}, // requestOptions + false // isDedicatedOfficialAccount + ) + + const latencyMs = Date.now() - startTime + + if (hasError) { + logger.warn(`⚠️ Test completed with error for account: ${account.name} - ${errorMessage}`) + return { + success: false, + error: errorMessage, + latencyMs, + timestamp: new Date().toISOString() + } + } + + logger.info(`✅ Test completed for account: ${account.name} (${latencyMs}ms)`) + + return { + success: true, + message: responseText.substring(0, 200), // 截取前200字符 + latencyMs, + model: capturedModel, + usage: capturedUsage, + timestamp: new Date().toISOString() + } + } catch (error) { + const latencyMs = Date.now() - startTime + logger.error(`❌ Test account connection (sync) failed:`, error.message) + + // 提取错误详情 + let errorMessage = error.message + if (error.response) { + errorMessage = + error.response.data?.error?.message || error.response.statusText || error.message + } + + return { + success: false, + error: errorMessage, + statusCode: error.response?.status, + latencyMs, + timestamp: new Date().toISOString() + } + } + } + // 🎯 健康检查 async healthCheck() { try { @@ -2547,6 +2772,17 @@ class ClaudeRelayService { } } } + + // 🔄 判断账户是否应该在 403 错误时进行重试 + // 仅 claude-official 类型账户(OAuth 或 Setup Token 授权)需要重试 + _shouldRetryOn403(accountType) { + return accountType === 'claude-official' + } + + // ⏱️ 等待指定毫秒数 + _sleep(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)) + } } module.exports = new ClaudeRelayService() diff --git a/src/services/userMessageQueueService.js b/src/services/userMessageQueueService.js index e35a9f64..2b4784a2 100644 --- a/src/services/userMessageQueueService.js +++ b/src/services/userMessageQueueService.js @@ -121,12 +121,23 @@ class UserMessageQueueService { * @param {string} accountId - 账户ID * @param {string} requestId - 请求ID(可选,会自动生成) * @param {number} timeoutMs - 超时时间(可选,使用配置默认值) + * @param {Object} accountConfig - 账户级配置(可选),优先级高于全局配置 + * @param {number} accountConfig.maxConcurrency - 账户级串行队列开关:>0启用,=0使用全局配置 * @returns {Promise<{acquired: boolean, requestId: string, error?: string}>} */ - async acquireQueueLock(accountId, requestId = null, timeoutMs = null) { + async acquireQueueLock(accountId, requestId = null, timeoutMs = null, accountConfig = null) { const cfg = await this.getConfig() - if (!cfg.enabled) { + // 账户级配置优先:maxConcurrency > 0 时强制启用,忽略全局开关 + let queueEnabled = cfg.enabled + if (accountConfig && accountConfig.maxConcurrency > 0) { + queueEnabled = true + logger.debug( + `📬 User message queue: account-level queue enabled for account ${accountId} (maxConcurrency=${accountConfig.maxConcurrency})` + ) + } + + if (!queueEnabled) { return { acquired: true, requestId: requestId || uuidv4(), skipped: true } } diff --git a/src/utils/featureFlags.js b/src/utils/featureFlags.js new file mode 100644 index 00000000..35802d55 --- /dev/null +++ b/src/utils/featureFlags.js @@ -0,0 +1,44 @@ +let config = {} +try { + // config/config.js 可能在某些环境不存在(例如仅拷贝了 config.example.js) + // 为保证可运行,这里做容错处理 + // eslint-disable-next-line global-require + config = require('../../config/config') +} catch (error) { + config = {} +} + +const parseBooleanEnv = (value) => { + if (typeof value === 'boolean') { + return value + } + if (typeof value !== 'string') { + return false + } + const normalized = value.trim().toLowerCase() + return normalized === 'true' || normalized === '1' || normalized === 'yes' || normalized === 'on' +} + +/** + * 是否允许执行“余额脚本”(安全开关) + * 默认开启,便于保持现有行为;如需禁用请显式设置 BALANCE_SCRIPT_ENABLED=false(环境变量优先) + */ +const isBalanceScriptEnabled = () => { + if ( + process.env.BALANCE_SCRIPT_ENABLED !== undefined && + process.env.BALANCE_SCRIPT_ENABLED !== '' + ) { + return parseBooleanEnv(process.env.BALANCE_SCRIPT_ENABLED) + } + + const fromConfig = + config?.accountBalance?.enableBalanceScript ?? + config?.features?.balanceScriptEnabled ?? + config?.security?.enableBalanceScript + + return typeof fromConfig === 'boolean' ? fromConfig : true +} + +module.exports = { + isBalanceScriptEnabled +} diff --git a/src/utils/warmupInterceptor.js b/src/utils/warmupInterceptor.js new file mode 100644 index 00000000..430d622d --- /dev/null +++ b/src/utils/warmupInterceptor.js @@ -0,0 +1,202 @@ +'use strict' + +const { v4: uuidv4 } = require('uuid') + +/** + * 预热请求拦截器 + * 检测并拦截低价值请求(标题生成、Warmup等),直接返回模拟响应 + */ + +/** + * 检测是否为预热请求 + * @param {Object} body - 请求体 + * @returns {boolean} + */ +function isWarmupRequest(body) { + if (!body) { + return false + } + + // 检查 messages + if (body.messages && Array.isArray(body.messages)) { + for (const msg of body.messages) { + // 处理 content 为数组的情况 + if (Array.isArray(msg.content)) { + for (const content of msg.content) { + if (content.type === 'text' && typeof content.text === 'string') { + if (isTitleOrWarmupText(content.text)) { + return true + } + } + } + } + // 处理 content 为字符串的情况 + if (typeof msg.content === 'string') { + if (isTitleOrWarmupText(msg.content)) { + return true + } + } + } + } + + // 检查 system prompt + if (body.system) { + const systemText = extractSystemText(body.system) + if (isTitleExtractionSystemPrompt(systemText)) { + return true + } + } + + return false +} + +/** + * 检查文本是否为标题生成或Warmup请求 + */ +function isTitleOrWarmupText(text) { + if (!text) { + return false + } + return ( + text.includes('Please write a 5-10 word title for the following conversation:') || + text === 'Warmup' + ) +} + +/** + * 检查system prompt是否为标题提取类型 + */ +function isTitleExtractionSystemPrompt(systemText) { + if (!systemText) { + return false + } + return systemText.includes( + 'nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title' + ) +} + +/** + * 从system字段提取文本 + */ +function extractSystemText(system) { + if (typeof system === 'string') { + return system + } + if (Array.isArray(system)) { + return system.map((s) => (typeof s === 'object' ? s.text || '' : String(s))).join('') + } + return '' +} + +/** + * 生成模拟的非流式响应 + * @param {string} model - 模型名称 + * @returns {Object} + */ +function buildMockWarmupResponse(model) { + return { + id: `msg_warmup_${uuidv4().replace(/-/g, '').slice(0, 20)}`, + type: 'message', + role: 'assistant', + content: [{ type: 'text', text: 'New Conversation' }], + model: model || 'claude-3-5-sonnet-20241022', + stop_reason: 'end_turn', + stop_sequence: null, + usage: { + input_tokens: 10, + output_tokens: 2 + } + } +} + +/** + * 发送模拟的流式响应 + * @param {Object} res - Express response对象 + * @param {string} model - 模型名称 + */ +function sendMockWarmupStream(res, model) { + const effectiveModel = model || 'claude-3-5-sonnet-20241022' + const messageId = `msg_warmup_${uuidv4().replace(/-/g, '').slice(0, 20)}` + + const events = [ + { + event: 'message_start', + data: { + message: { + content: [], + id: messageId, + model: effectiveModel, + role: 'assistant', + stop_reason: null, + stop_sequence: null, + type: 'message', + usage: { input_tokens: 10, output_tokens: 0 } + }, + type: 'message_start' + } + }, + { + event: 'content_block_start', + data: { + content_block: { text: '', type: 'text' }, + index: 0, + type: 'content_block_start' + } + }, + { + event: 'content_block_delta', + data: { + delta: { text: 'New', type: 'text_delta' }, + index: 0, + type: 'content_block_delta' + } + }, + { + event: 'content_block_delta', + data: { + delta: { text: ' Conversation', type: 'text_delta' }, + index: 0, + type: 'content_block_delta' + } + }, + { + event: 'content_block_stop', + data: { index: 0, type: 'content_block_stop' } + }, + { + event: 'message_delta', + data: { + delta: { stop_reason: 'end_turn', stop_sequence: null }, + type: 'message_delta', + usage: { input_tokens: 10, output_tokens: 2 } + } + }, + { + event: 'message_stop', + data: { type: 'message_stop' } + } + ] + + let index = 0 + const sendNext = () => { + if (index >= events.length) { + res.end() + return + } + + const { event, data } = events[index] + res.write(`event: ${event}\ndata: ${JSON.stringify(data)}\n\n`) + index++ + + // 模拟网络延迟 + setTimeout(sendNext, 20) + } + + sendNext() +} + +module.exports = { + isWarmupRequest, + buildMockWarmupResponse, + sendMockWarmupStream +} diff --git a/tests/accountBalanceService.test.js b/tests/accountBalanceService.test.js new file mode 100644 index 00000000..c2a9c3a8 --- /dev/null +++ b/tests/accountBalanceService.test.js @@ -0,0 +1,218 @@ +// Mock logger,避免测试输出污染控制台 +jest.mock('../src/utils/logger', () => ({ + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn() +})) + +const accountBalanceServiceModule = require('../src/services/accountBalanceService') + +const { AccountBalanceService } = accountBalanceServiceModule + +describe('AccountBalanceService', () => { + const originalBalanceScriptEnabled = process.env.BALANCE_SCRIPT_ENABLED + + afterEach(() => { + if (originalBalanceScriptEnabled === undefined) { + delete process.env.BALANCE_SCRIPT_ENABLED + } else { + process.env.BALANCE_SCRIPT_ENABLED = originalBalanceScriptEnabled + } + }) + + const mockLogger = { + debug: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + error: jest.fn() + } + + const buildMockRedis = () => ({ + getLocalBalance: jest.fn().mockResolvedValue(null), + setLocalBalance: jest.fn().mockResolvedValue(undefined), + getAccountBalance: jest.fn().mockResolvedValue(null), + setAccountBalance: jest.fn().mockResolvedValue(undefined), + deleteAccountBalance: jest.fn().mockResolvedValue(undefined), + getBalanceScriptConfig: jest.fn().mockResolvedValue(null), + getAccountUsageStats: jest.fn().mockResolvedValue({ + total: { requests: 10 }, + daily: { requests: 2, cost: 20 }, + monthly: { requests: 5 } + }), + getDateInTimezone: (date) => new Date(date.getTime() + 8 * 3600 * 1000) + }) + + it('should normalize platform aliases', () => { + const service = new AccountBalanceService({ redis: buildMockRedis(), logger: mockLogger }) + expect(service.normalizePlatform('claude-official')).toBe('claude') + expect(service.normalizePlatform('azure-openai')).toBe('azure_openai') + expect(service.normalizePlatform('gemini-api')).toBe('gemini-api') + }) + + it('should build local quota/balance from dailyQuota and local dailyCost', async () => { + const mockRedis = buildMockRedis() + const service = new AccountBalanceService({ redis: mockRedis, logger: mockLogger }) + + service._computeMonthlyCost = jest.fn().mockResolvedValue(30) + service._computeTotalCost = jest.fn().mockResolvedValue(123.45) + + const account = { id: 'acct-1', name: 'A', dailyQuota: '100', quotaResetTime: '00:00' } + const result = await service._getAccountBalanceForAccount(account, 'claude-console', { + queryApi: false, + useCache: true + }) + + expect(result.success).toBe(true) + expect(result.data.source).toBe('local') + expect(result.data.balance.amount).toBeCloseTo(80, 6) + expect(result.data.quota.percentage).toBeCloseTo(20, 6) + expect(result.data.statistics.totalCost).toBeCloseTo(123.45, 6) + expect(mockRedis.setLocalBalance).toHaveBeenCalled() + }) + + it('should use cached balance when account has no dailyQuota', async () => { + const mockRedis = buildMockRedis() + mockRedis.getAccountBalance.mockResolvedValue({ + status: 'success', + balance: 12.34, + currency: 'USD', + quota: null, + errorMessage: '', + lastRefreshAt: '2025-01-01T00:00:00Z', + ttlSeconds: 120 + }) + + const service = new AccountBalanceService({ redis: mockRedis, logger: mockLogger }) + service._computeMonthlyCost = jest.fn().mockResolvedValue(0) + service._computeTotalCost = jest.fn().mockResolvedValue(0) + + const account = { id: 'acct-2', name: 'B' } + const result = await service._getAccountBalanceForAccount(account, 'openai', { + queryApi: false, + useCache: true + }) + + expect(result.data.source).toBe('cache') + expect(result.data.balance.amount).toBeCloseTo(12.34, 6) + expect(result.data.lastRefreshAt).toBe('2025-01-01T00:00:00Z') + }) + + it('should not cache provider errors and fallback to local when queryApi=true', async () => { + const mockRedis = buildMockRedis() + const service = new AccountBalanceService({ redis: mockRedis, logger: mockLogger }) + + service._computeMonthlyCost = jest.fn().mockResolvedValue(0) + service._computeTotalCost = jest.fn().mockResolvedValue(0) + + service.registerProvider('openai', { + queryBalance: () => { + throw new Error('boom') + } + }) + + const account = { id: 'acct-3', name: 'C' } + const result = await service._getAccountBalanceForAccount(account, 'openai', { + queryApi: true, + useCache: false + }) + + expect(mockRedis.setAccountBalance).not.toHaveBeenCalled() + expect(result.data.source).toBe('local') + expect(result.data.status).toBe('error') + expect(result.data.error).toBe('boom') + }) + + it('should ignore script config when balance script is disabled', async () => { + process.env.BALANCE_SCRIPT_ENABLED = 'false' + + const mockRedis = buildMockRedis() + mockRedis.getBalanceScriptConfig.mockResolvedValue({ + scriptBody: '({ request: { url: "http://example.com" }, extractor: function(){ return {} } })' + }) + + const service = new AccountBalanceService({ redis: mockRedis, logger: mockLogger }) + service._computeMonthlyCost = jest.fn().mockResolvedValue(0) + service._computeTotalCost = jest.fn().mockResolvedValue(0) + + const provider = { queryBalance: jest.fn().mockResolvedValue({ balance: 1, currency: 'USD' }) } + service.registerProvider('openai', provider) + + const scriptSpy = jest.spyOn(service, '_getBalanceFromScript') + + const account = { id: 'acct-script-off', name: 'S' } + const result = await service._getAccountBalanceForAccount(account, 'openai', { + queryApi: true, + useCache: false + }) + + expect(provider.queryBalance).toHaveBeenCalled() + expect(scriptSpy).not.toHaveBeenCalled() + expect(result.data.source).toBe('api') + }) + + it('should prefer script when configured and enabled', async () => { + process.env.BALANCE_SCRIPT_ENABLED = 'true' + + const mockRedis = buildMockRedis() + mockRedis.getBalanceScriptConfig.mockResolvedValue({ + scriptBody: '({ request: { url: "http://example.com" }, extractor: function(){ return {} } })' + }) + + const service = new AccountBalanceService({ redis: mockRedis, logger: mockLogger }) + service._computeMonthlyCost = jest.fn().mockResolvedValue(0) + service._computeTotalCost = jest.fn().mockResolvedValue(0) + + const provider = { queryBalance: jest.fn().mockResolvedValue({ balance: 2, currency: 'USD' }) } + service.registerProvider('openai', provider) + + jest.spyOn(service, '_getBalanceFromScript').mockResolvedValue({ + status: 'success', + balance: 3, + currency: 'USD', + quota: null, + queryMethod: 'script', + rawData: { ok: true }, + lastRefreshAt: '2025-01-01T00:00:00Z', + errorMessage: '' + }) + + const account = { id: 'acct-script-on', name: 'T' } + const result = await service._getAccountBalanceForAccount(account, 'openai', { + queryApi: true, + useCache: false + }) + + expect(provider.queryBalance).not.toHaveBeenCalled() + expect(result.data.source).toBe('api') + expect(result.data.balance.amount).toBeCloseTo(3, 6) + expect(result.data.lastRefreshAt).toBe('2025-01-01T00:00:00Z') + }) + + it('should count low balance once per account in summary', async () => { + const mockRedis = buildMockRedis() + const service = new AccountBalanceService({ redis: mockRedis, logger: mockLogger }) + + service.getSupportedPlatforms = () => ['claude-console'] + service.getAllAccountsByPlatform = async () => [{ id: 'acct-4', name: 'D' }] + service._getAccountBalanceForAccount = async () => ({ + success: true, + data: { + accountId: 'acct-4', + platform: 'claude-console', + balance: { amount: 5, currency: 'USD', formattedAmount: '$5.00' }, + quota: { percentage: 95 }, + statistics: { totalCost: 1 }, + source: 'local', + lastRefreshAt: '2025-01-01T00:00:00Z', + cacheExpiresAt: null, + status: 'success', + error: null + } + }) + + const summary = await service.getBalanceSummary() + expect(summary.lowBalanceCount).toBe(1) + expect(summary.platforms['claude-console'].lowBalanceCount).toBe(1) + }) +}) diff --git a/web/admin-spa/src/components/accounts/AccountBalanceScriptModal.vue b/web/admin-spa/src/components/accounts/AccountBalanceScriptModal.vue new file mode 100644 index 00000000..17f2be00 --- /dev/null +++ b/web/admin-spa/src/components/accounts/AccountBalanceScriptModal.vue @@ -0,0 +1,302 @@ + + + + + diff --git a/web/admin-spa/src/components/accounts/AccountForm.vue b/web/admin-spa/src/components/accounts/AccountForm.vue index 7bd4b883..1d185fa4 100644 --- a/web/admin-spa/src/components/accounts/AccountForm.vue +++ b/web/admin-spa/src/components/accounts/AccountForm.vue @@ -1662,6 +1662,47 @@ + +
+ +
+ + +
+ +
+
+ +
+ +
+ + +
+ +
+
+ +
+ + + +
+ + +
@@ -1238,6 +1282,15 @@ 测试 +
+ +
+

余额/配额

+ +
+ +
+
+
@@ -1707,6 +1780,15 @@ 测试 + +
+ +
+
+
+
+

+ 账户余额/配额 +

+

+ {{ formatCurrencyUsd(balanceSummary.totalBalance || 0) }} +

+

+ 低余额: {{ balanceSummary.lowBalanceCount || 0 }} | 总成本: + {{ formatCurrencyUsd(balanceSummary.totalCost || 0) }} +

+
+
+ +
+
+ +
+

+ 更新时间: {{ formatLastUpdate(balanceSummaryUpdatedAt) }} +

+ +
+
+ +
+
+

低余额账户

+ + {{ lowBalanceAccounts.length }} 个 + +
+ +
+ 正在加载... +
+
+ 全部正常 +
+
+
+
+
+ {{ account.name || account.accountId }} +
+ + {{ getBalancePlatformLabel(account.platform) }} + +
+
+ 余额: {{ account.balance.formattedAmount }} + 今日成本: {{ formatCurrencyUsd(account.statistics?.dailyCost || 0) }} +
+
+
+ 配额使用 + + {{ account.quota.percentage.toFixed(1) }}% + +
+
+
+
+
+
+
+
+
+
{ + const map = { + claude: 'Claude', + 'claude-console': 'Claude Console', + gemini: 'Gemini', + 'gemini-api': 'Gemini API', + openai: 'OpenAI', + 'openai-responses': 'OpenAI Responses', + azure_openai: 'Azure OpenAI', + bedrock: 'Bedrock', + droid: 'Droid', + ccr: 'CCR' + } + return map[platform] || platform +} + +const lowBalanceAccounts = computed(() => { + const result = [] + const platforms = balanceSummary.value?.platforms || {} + + Object.entries(platforms).forEach(([platform, data]) => { + const list = Array.isArray(data?.accounts) ? data.accounts : [] + list.forEach((entry) => { + const accountData = entry?.data + if (!accountData) return + + const amount = accountData.balance?.amount + const percentage = accountData.quota?.percentage + + const isLowBalance = typeof amount === 'number' && amount < 10 + const isHighUsage = typeof percentage === 'number' && percentage > 90 + + if (isLowBalance || isHighUsage) { + result.push({ + ...accountData, + name: entry?.name || accountData.accountId, + platform: accountData.platform || platform + }) + } + }) + }) + + return result +}) + +const formatCurrencyUsd = (amount) => { + const value = Number(amount) + if (!Number.isFinite(value)) return '$0.00' + if (value >= 1) return `$${value.toFixed(2)}` + if (value >= 0.01) return `$${value.toFixed(3)}` + return `$${value.toFixed(6)}` +} + +const formatLastUpdate = (isoString) => { + if (!isoString) return '未知' + const date = new Date(isoString) + if (Number.isNaN(date.getTime())) return '未知' + return date.toLocaleTimeString('zh-CN', { hour: '2-digit', minute: '2-digit' }) +} + +const loadBalanceSummary = async () => { + loadingBalanceSummary.value = true + try { + const response = await apiClient.get('/admin/accounts/balance/summary') + if (response?.success) { + balanceSummary.value = response.data || { + totalBalance: 0, + totalCost: 0, + lowBalanceCount: 0, + platforms: {} + } + balanceSummaryUpdatedAt.value = new Date().toISOString() + } + } catch (error) { + console.debug('加载余额汇总失败:', error) + showToast('加载余额汇总失败', 'error') + } finally { + loadingBalanceSummary.value = false + } +} + // 自动刷新相关 const autoRefreshEnabled = ref(false) const autoRefreshInterval = ref(30) // 秒 @@ -1488,7 +1680,7 @@ async function refreshAllData() { isRefreshing.value = true try { - await Promise.all([loadDashboardData(), refreshChartsData()]) + await Promise.all([loadDashboardData(), refreshChartsData(), loadBalanceSummary()]) } finally { isRefreshing.value = false }