diff --git a/.env.example b/.env.example index 704d0a8a..eeb10de0 100644 --- a/.env.example +++ b/.env.example @@ -61,6 +61,9 @@ PROXY_USE_IPV4=true # ⏱️ 请求超时配置 REQUEST_TIMEOUT=600000 # 请求超时设置(毫秒),默认10分钟 +# 🔧 请求体大小配置 +REQUEST_MAX_SIZE_MB=60 + # 📈 使用限制 DEFAULT_TOKEN_LIMIT=1000000 @@ -75,6 +78,8 @@ TOKEN_USAGE_RETENTION=2592000000 HEALTH_CHECK_INTERVAL=60000 TIMEZONE_OFFSET=8 # UTC偏移小时数,默认+8(中国时区) METRICS_WINDOW=5 # 实时指标统计窗口(分钟),可选1-60,默认5分钟 +# 启动时清理残留的并发排队计数器(默认true,多实例部署时建议设为false) +CLEAR_CONCURRENCY_QUEUES_ON_STARTUP=true # 🎨 Web 界面配置 WEB_TITLE=Claude Relay Service diff --git a/CLAUDE.md b/CLAUDE.md index 1eac1b03..892b4758 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -22,6 +22,7 @@ Claude Relay Service 是一个多平台 AI API 中转服务,支持 **Claude ( - **权限控制**: API Key支持权限配置(all/claude/gemini/openai等),控制可访问的服务类型 - **客户端限制**: 基于User-Agent的客户端识别和限制,支持ClaudeCode、Gemini-CLI等预定义客户端 - **模型黑名单**: 支持API Key级别的模型访问限制 +- **并发请求排队**: 当API Key并发数超限时,请求进入队列等待而非立即返回429,支持配置最大排队数、超时时间,适用于Claude Code Agent并行工具调用场景 ### 主要服务组件 @@ -60,6 +61,7 @@ Claude Relay Service 是一个多平台 AI API 中转服务,支持 **Claude ( - **apiKeyService.js**: API Key管理,验证、限流、使用统计、成本计算 - **userService.js**: 用户管理系统,支持用户注册、登录、API Key管理 +- **userMessageQueueService.js**: 用户消息串行队列,防止同账户并发用户消息触发限流 - **pricingService.js**: 定价服务,模型价格管理和成本计算 - **costInitService.js**: 成本数据初始化服务 - **webhookService.js**: Webhook通知服务 @@ -185,12 +187,17 @@ npm run service:stop # 停止服务 - `CLAUDE_OVERLOAD_HANDLING_MINUTES`: Claude 529错误处理持续时间(分钟,0表示禁用) - `STICKY_SESSION_TTL_HOURS`: 粘性会话TTL(小时,默认1) - `STICKY_SESSION_RENEWAL_THRESHOLD_MINUTES`: 粘性会话续期阈值(分钟,默认0) +- `USER_MESSAGE_QUEUE_ENABLED`: 启用用户消息串行队列(默认false) +- `USER_MESSAGE_QUEUE_DELAY_MS`: 用户消息请求间隔(毫秒,默认200) +- `USER_MESSAGE_QUEUE_TIMEOUT_MS`: 队列等待超时(毫秒,默认5000,锁持有时间短无需长等待) +- `USER_MESSAGE_QUEUE_LOCK_TTL_MS`: 锁TTL(毫秒,默认5000,请求发送后立即释放无需长TTL) - `METRICS_WINDOW`: 实时指标统计窗口(分钟,1-60,默认5) - `MAX_API_KEYS_PER_USER`: 每用户最大API Key数量(默认1) - `ALLOW_USER_DELETE_API_KEYS`: 允许用户删除自己的API Keys(默认false) - `DEBUG_HTTP_TRAFFIC`: 启用HTTP请求/响应调试日志(默认false,仅开发环境) - `PROXY_USE_IPV4`: 代理使用IPv4(默认true) - `REQUEST_TIMEOUT`: 请求超时时间(毫秒,默认600000即10分钟) +- `CLEAR_CONCURRENCY_QUEUES_ON_STARTUP`: 启动时清理残留的并发排队计数器(默认true,多实例部署时建议设为false) #### AWS Bedrock配置(可选) - `CLAUDE_CODE_USE_BEDROCK`: 启用Bedrock(设置为1启用) @@ -337,6 +344,35 @@ npm run setup # 自动生成密钥并创建管理员账户 11. **速率限制未清理**: rateLimitCleanupService每5分钟自动清理过期限流状态 12. **成本统计不准确**: 运行 `npm run init:costs` 初始化成本数据,检查pricingService是否正确加载模型价格 13. **缓存命中率低**: 查看缓存监控统计,调整LRU缓存大小配置 +14. **用户消息队列超时**: 优化后锁持有时间已从分钟级降到毫秒级(请求发送后立即释放),默认 `USER_MESSAGE_QUEUE_TIMEOUT_MS=5000` 已足够。如仍有超时,检查网络延迟或禁用此功能(`USER_MESSAGE_QUEUE_ENABLED=false`) +15. **并发请求排队问题**: + - 排队超时:检查 `concurrentRequestQueueTimeoutMs` 配置是否合理(默认10秒) + - 排队数过多:调整 `concurrentRequestQueueMaxSize` 和 `concurrentRequestQueueMaxSizeMultiplier` + - 查看排队统计:访问 `/admin/concurrency-queue/stats` 接口查看 entered/success/timeout/cancelled/socket_changed/rejected_overload 统计 + - 排队计数泄漏:系统重启时自动清理,或访问 `/admin/concurrency-queue` DELETE 接口手动清理 + - Socket 身份验证失败:查看 `socket_changed` 统计,如果频繁发生,检查代理配置或客户端连接稳定性 + - 健康检查拒绝:查看 `rejected_overload` 统计,表示队列过载时的快速失败次数 + +### 代理配置要求(并发请求排队) + +使用并发请求排队功能时,需要正确配置代理(如 Nginx)的超时参数: + +- **推荐配置**: `proxy_read_timeout >= max(2 × concurrentRequestQueueTimeoutMs, 60s)` + - 当前默认排队超时 10 秒,Nginx 默认 `proxy_read_timeout = 60s` 已满足要求 + - 如果调整排队超时到 60 秒,推荐代理超时 ≥ 120 秒 +- **Nginx 配置示例**: + ```nginx + location /api/ { + proxy_read_timeout 120s; # 排队超时 60s 时推荐 120s + proxy_connect_timeout 10s; + # ...其他配置 + } + ``` +- **企业防火墙环境**: + - 某些企业防火墙可能静默关闭长时间无数据的连接(20-40 秒) + - 如遇此问题,联系网络管理员调整空闲连接超时策略 + - 或降低 `concurrentRequestQueueTimeoutMs` 配置 +- **后续升级说明**: 如有需要,后续版本可能提供可选的轻量级心跳机制 ### 调试工具 @@ -449,6 +485,15 @@ npm run setup # 自动生成密钥并创建管理员账户 - **缓存优化**: 多层LRU缓存(解密缓存、账户缓存),全局缓存监控和统计 - **成本追踪**: 实时token使用统计(input/output/cache_create/cache_read)和成本计算(基于pricingService) - **并发控制**: Redis Sorted Set实现的并发计数,支持自动过期清理 +- **并发请求排队**: 当API Key并发超限时,请求进入队列等待而非直接返回429 + - **工作原理**: 采用「先占后检查」模式,每次轮询尝试占位,超限则释放继续等待 + - **指数退避**: 初始200ms,指数增长至最大2秒,带±20%抖动防惊群效应 + - **智能清理**: 排队计数有TTL保护(超时+30秒),进程崩溃也能自动清理 + - **Socket身份验证**: 使用UUID token + socket对象引用双重验证,避免HTTP Keep-Alive连接复用导致的身份混淆 + - **健康检查**: P90等待时间超过阈值时快速失败(返回429),避免新请求在过载时继续排队 + - **配置参数**: `concurrentRequestQueueEnabled`(默认false)、`concurrentRequestQueueMaxSize`(默认3)、`concurrentRequestQueueMaxSizeMultiplier`(默认0)、`concurrentRequestQueueTimeoutMs`(默认10秒)、`concurrentRequestQueueMaxRedisFailCount`(默认5)、`concurrentRequestQueueHealthCheckEnabled`(默认true)、`concurrentRequestQueueHealthThreshold`(默认0.8) + - **最大排队数**: max(固定值, 并发限制×倍数),例如并发限制=10、倍数=2时最大排队数=20 + - **适用场景**: Claude Code Agent并行工具调用、批量请求处理 - **客户端识别**: 基于User-Agent的客户端限制,支持预定义客户端(ClaudeCode、Gemini-CLI等) - **错误处理**: 529错误自动标记账户过载状态,配置时长内自动排除该账户 @@ -508,8 +553,16 @@ npm run setup # 自动生成密钥并创建管理员账户 - `overload:{accountId}` - 账户过载状态(529错误) - **并发控制**: - `concurrency:{accountId}` - Redis Sorted Set实现的并发计数 +- **并发请求排队**: + - `concurrency:queue:{apiKeyId}` - API Key级别的排队计数器(TTL由 `concurrentRequestQueueTimeoutMs` + 30秒缓冲决定) + - `concurrency:queue:stats:{apiKeyId}` - 排队统计(entered/success/timeout/cancelled) + - `concurrency:queue:wait_times:{apiKeyId}` - 按API Key的等待时间记录(用于P50/P90/P99计算) + - `concurrency:queue:wait_times:global` - 全局等待时间记录 - **Webhook配置**: - `webhook_config:{id}` - Webhook配置 +- **用户消息队列**: + - `user_msg_queue_lock:{accountId}` - 用户消息队列锁(当前持有者requestId) + - `user_msg_queue_last:{accountId}` - 上次请求完成时间戳(用于延迟计算) - **系统信息**: - `system_info` - 系统状态缓存 - `model_pricing` - 模型价格数据(pricingService) diff --git a/VERSION b/VERSION index 40ff1f13..df06a0ad 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.1.226 +1.1.237 diff --git a/config/config.example.js b/config/config.example.js index 5395142a..9cf26002 100644 --- a/config/config.example.js +++ b/config/config.example.js @@ -203,6 +203,15 @@ const config = { development: { debug: process.env.DEBUG === 'true', hotReload: process.env.HOT_RELOAD === 'true' + }, + + // 📬 用户消息队列配置 + // 优化说明:锁在请求发送成功后立即释放(而非请求完成后),因为 Claude API 限流基于请求发送时刻计算 + userMessageQueue: { + enabled: process.env.USER_MESSAGE_QUEUE_ENABLED === 'true', // 默认关闭 + delayMs: parseInt(process.env.USER_MESSAGE_QUEUE_DELAY_MS) || 200, // 请求间隔(毫秒) + timeoutMs: parseInt(process.env.USER_MESSAGE_QUEUE_TIMEOUT_MS) || 5000, // 队列等待超时(毫秒),锁持有时间短,无需长等待 + lockTtlMs: parseInt(process.env.USER_MESSAGE_QUEUE_LOCK_TTL_MS) || 5000 // 锁TTL(毫秒),5秒足以覆盖请求发送 } } diff --git a/docker-compose.yml b/docker-compose.yml index 79b9afb8..d8f78a24 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -21,6 +21,9 @@ services: - PORT=3000 - HOST=0.0.0.0 + # 🔧 请求体大小配置 + - REQUEST_MAX_SIZE_MB=60 + # 🔐 安全配置(必填) - JWT_SECRET=${JWT_SECRET} # 必填:至少32字符的随机字符串 - ENCRYPTION_KEY=${ENCRYPTION_KEY} # 必填:32字符的加密密钥 diff --git a/package-lock.json b/package-lock.json index c6dccd11..d9ebcff0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -26,6 +26,7 @@ "ioredis": "^5.3.2", "ldapjs": "^3.0.7", "morgan": "^1.10.0", + "node-cron": "^4.2.1", "nodemailer": "^7.0.6", "ora": "^5.4.1", "rate-limiter-flexible": "^5.0.5", @@ -891,7 +892,6 @@ "integrity": "sha512-2BCOP7TN8M+gVDj7/ht3hsaO/B/n5oDbiAyyvnRlNOs+u1o+JWNYTQrmpuNp1/Wq2gcFrI01JAW+paEKDMx/CA==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@babel/code-frame": "^7.27.1", "@babel/generator": "^7.28.3", @@ -3000,7 +3000,6 @@ "integrity": "sha512-yCAeZl7a0DxgNVteXFHt9+uyFbqXGy/ShC4BlcHkoE0AfGXYv/BUiplV72DjMYXHDBXFjhvr6DD1NiRVfB4j8g==", "devOptional": true, "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~6.21.0" } @@ -3082,7 +3081,6 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -3538,7 +3536,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "caniuse-lite": "^1.0.30001737", "electron-to-chromium": "^1.5.211", @@ -4426,7 +4423,6 @@ "deprecated": "This version is no longer supported. Please see https://eslint.org/version-support for other options.", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.6.1", @@ -4483,7 +4479,6 @@ "integrity": "sha512-82GZUjRS0p/jganf6q1rEO25VSoHH0hKPCTrgillPjdI/3bgBhAE1QzHrHTizjpRvy6pGAvKjDJtk2pF9NDq8w==", "dev": true, "license": "MIT", - "peer": true, "bin": { "eslint-config-prettier": "bin/cli.js" }, @@ -7034,6 +7029,15 @@ "node": ">= 0.6" } }, + "node_modules/node-cron": { + "version": "4.2.1", + "resolved": "https://registry.npmmirror.com/node-cron/-/node-cron-4.2.1.tgz", + "integrity": "sha512-lgimEHPE/QDgFlywTd8yTR61ptugX3Qer29efeyWw2rv259HtGBNn1vZVmp8lB9uo9wC0t/AT4iGqXxia+CJFg==", + "license": "ISC", + "engines": { + "node": ">=6.0.0" + } + }, "node_modules/node-domexception": { "version": "1.0.0", "resolved": "https://registry.npmmirror.com/node-domexception/-/node-domexception-1.0.0.tgz", @@ -7582,7 +7586,6 @@ "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", "dev": true, "license": "MIT", - "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -9101,7 +9104,6 @@ "resolved": "https://registry.npmmirror.com/winston/-/winston-3.17.0.tgz", "integrity": "sha512-DLiFIXYC5fMPxaRg832S6F5mJYvePtmO5G9v9IgUFPhXm9/GkXarH/TUrBAVzhTCzAj9anE/+GjrgXp/54nOgw==", "license": "MIT", - "peer": true, "dependencies": { "@colors/colors": "^1.6.0", "@dabh/diagnostics": "^2.0.2", diff --git a/package.json b/package.json index 2b7ffa25..6ef88e60 100644 --- a/package.json +++ b/package.json @@ -65,6 +65,7 @@ "ioredis": "^5.3.2", "ldapjs": "^3.0.7", "morgan": "^1.10.0", + "node-cron": "^4.2.1", "nodemailer": "^7.0.6", "ora": "^5.4.1", "rate-limiter-flexible": "^5.0.5", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index dafee4e7..9e8dc0fb 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -59,6 +59,9 @@ importers: morgan: specifier: ^1.10.0 version: 1.10.1 + node-cron: + specifier: ^4.2.1 + version: 4.2.1 nodemailer: specifier: ^7.0.6 version: 7.0.11 @@ -108,6 +111,9 @@ importers: prettier: specifier: ^3.6.2 version: 3.7.4 + prettier-plugin-tailwindcss: + specifier: ^0.7.2 + version: 0.7.2(prettier@3.7.4) supertest: specifier: ^6.3.3 version: 6.3.4 @@ -2144,6 +2150,10 @@ packages: resolution: {integrity: sha512-myRT3DiWPHqho5PrJaIRyaMv2kgYf0mUVgBNOYMuCH5Ki1yEiQaf/ZJuQ62nvpc44wL5WDbTX7yGJi1Neevw8w==} engines: {node: '>= 0.6'} + node-cron@4.2.1: + resolution: {integrity: sha512-lgimEHPE/QDgFlywTd8yTR61ptugX3Qer29efeyWw2rv259HtGBNn1vZVmp8lB9uo9wC0t/AT4iGqXxia+CJFg==} + engines: {node: '>=6.0.0'} + node-domexception@1.0.0: resolution: {integrity: sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==} engines: {node: '>=10.5.0'} @@ -2302,6 +2312,61 @@ packages: resolution: {integrity: sha512-GbK2cP9nraSSUF9N2XwUwqfzlAFlMNYYl+ShE/V+H8a9uNl/oUqB1w2EL54Jh0OlyRSd8RfWYJ3coVS4TROP2w==} engines: {node: '>=6.0.0'} + prettier-plugin-tailwindcss@0.7.2: + resolution: {integrity: sha512-LkphyK3Fw+q2HdMOoiEHWf93fNtYJwfamoKPl7UwtjFQdei/iIBoX11G6j706FzN3ymX9mPVi97qIY8328vdnA==} + engines: {node: '>=20.19'} + peerDependencies: + '@ianvs/prettier-plugin-sort-imports': '*' + '@prettier/plugin-hermes': '*' + '@prettier/plugin-oxc': '*' + '@prettier/plugin-pug': '*' + '@shopify/prettier-plugin-liquid': '*' + '@trivago/prettier-plugin-sort-imports': '*' + '@zackad/prettier-plugin-twig': '*' + prettier: ^3.0 + prettier-plugin-astro: '*' + prettier-plugin-css-order: '*' + prettier-plugin-jsdoc: '*' + prettier-plugin-marko: '*' + prettier-plugin-multiline-arrays: '*' + prettier-plugin-organize-attributes: '*' + prettier-plugin-organize-imports: '*' + prettier-plugin-sort-imports: '*' + prettier-plugin-svelte: '*' + peerDependenciesMeta: + '@ianvs/prettier-plugin-sort-imports': + optional: true + '@prettier/plugin-hermes': + optional: true + '@prettier/plugin-oxc': + optional: true + '@prettier/plugin-pug': + optional: true + '@shopify/prettier-plugin-liquid': + optional: true + '@trivago/prettier-plugin-sort-imports': + optional: true + '@zackad/prettier-plugin-twig': + optional: true + prettier-plugin-astro: + optional: true + prettier-plugin-css-order: + optional: true + prettier-plugin-jsdoc: + optional: true + prettier-plugin-marko: + optional: true + prettier-plugin-multiline-arrays: + optional: true + prettier-plugin-organize-attributes: + optional: true + prettier-plugin-organize-imports: + optional: true + prettier-plugin-sort-imports: + optional: true + prettier-plugin-svelte: + optional: true + prettier@3.7.4: resolution: {integrity: sha512-v6UNi1+3hSlVvv8fSaoUbggEM5VErKmmpGA7Pl3HF8V6uKY7rvClBOJlH6yNwQtfTueNkGVpOv/mtWL9L4bgRA==} engines: {node: '>=14'} @@ -5692,6 +5757,8 @@ snapshots: negotiator@0.6.4: {} + node-cron@4.2.1: {} + node-domexception@1.0.0: {} node-fetch@3.3.2: @@ -5840,6 +5907,10 @@ snapshots: dependencies: fast-diff: 1.3.0 + prettier-plugin-tailwindcss@0.7.2(prettier@3.7.4): + dependencies: + prettier: 3.7.4 + prettier@3.7.4: {} pretty-format@29.7.0: diff --git a/src/app.js b/src/app.js index 77047247..41edc483 100644 --- a/src/app.js +++ b/src/app.js @@ -584,6 +584,20 @@ class Application { // 使用 Lua 脚本批量清理所有过期项 for (const key of keys) { + // 跳过非 Sorted Set 类型的键(这些键有各自的清理逻辑) + // - concurrency:queue:stats:* 是 Hash 类型 + // - concurrency:queue:wait_times:* 是 List 类型 + // - concurrency:queue:* (不含stats/wait_times) 是 String 类型 + if ( + key.startsWith('concurrency:queue:stats:') || + key.startsWith('concurrency:queue:wait_times:') || + (key.startsWith('concurrency:queue:') && + !key.includes(':stats:') && + !key.includes(':wait_times:')) + ) { + continue + } + try { const cleaned = await redis.client.eval( ` @@ -625,6 +639,41 @@ class Application { }, 60000) // 每分钟执行一次 logger.info('🔢 Concurrency cleanup task started (running every 1 minute)') + + // 📬 启动用户消息队列服务 + const userMessageQueueService = require('./services/userMessageQueueService') + // 先清理服务重启后残留的锁,防止旧锁阻塞新请求 + userMessageQueueService.cleanupStaleLocks().then(() => { + // 然后启动定时清理任务 + userMessageQueueService.startCleanupTask() + }) + + // 🚦 清理服务重启后残留的并发排队计数器 + // 多实例部署时建议关闭此开关,避免新实例启动时清空其他实例的队列计数 + // 可通过 DELETE /admin/concurrency/queue 接口手动清理 + const clearQueuesOnStartup = process.env.CLEAR_CONCURRENCY_QUEUES_ON_STARTUP !== 'false' + if (clearQueuesOnStartup) { + redis.clearAllConcurrencyQueues().catch((error) => { + logger.error('❌ Error clearing concurrency queues on startup:', error) + }) + } else { + logger.info( + '🚦 Skipping concurrency queue cleanup on startup (CLEAR_CONCURRENCY_QUEUES_ON_STARTUP=false)' + ) + } + + // 🧪 启动账户定时测试调度器 + // 根据配置定期测试账户连通性并保存测试历史 + const accountTestSchedulerEnabled = + process.env.ACCOUNT_TEST_SCHEDULER_ENABLED !== 'false' && + config.accountTestScheduler?.enabled !== false + if (accountTestSchedulerEnabled) { + const accountTestSchedulerService = require('./services/accountTestSchedulerService') + accountTestSchedulerService.start() + logger.info('🧪 Account test scheduler service started') + } else { + logger.info('🧪 Account test scheduler service disabled') + } } setupGracefulShutdown() { @@ -661,6 +710,15 @@ class Application { logger.error('❌ Error stopping rate limit cleanup service:', error) } + // 停止用户消息队列清理服务 + try { + const userMessageQueueService = require('./services/userMessageQueueService') + userMessageQueueService.stopCleanupTask() + logger.info('📬 User message queue service stopped') + } catch (error) { + logger.error('❌ Error stopping user message queue service:', error) + } + // 停止费用排序索引服务 try { const costRankService = require('./services/costRankService') @@ -670,6 +728,15 @@ class Application { logger.error('❌ Error stopping cost rank service:', error) } + // 停止账户定时测试调度器 + try { + const accountTestSchedulerService = require('./services/accountTestSchedulerService') + accountTestSchedulerService.stop() + logger.info('🧪 Account test scheduler service stopped') + } catch (error) { + logger.error('❌ Error stopping account test scheduler service:', error) + } + // 🔢 清理所有并发计数(Phase 1 修复:防止重启泄漏) try { logger.info('🔢 Cleaning up all concurrency counters...') diff --git a/src/middleware/auth.js b/src/middleware/auth.js index 484d5743..a3d2311a 100644 --- a/src/middleware/auth.js +++ b/src/middleware/auth.js @@ -8,6 +8,102 @@ const redis = require('../models/redis') const ClientValidator = require('../validators/clientValidator') const ClaudeCodeValidator = require('../validators/clients/claudeCodeValidator') const claudeRelayConfigService = require('../services/claudeRelayConfigService') +const { calculateWaitTimeStats } = require('../utils/statsHelper') + +// 工具函数 +function sleep(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + +/** + * 检查排队是否过载,决定是否应该快速失败 + * 详见 design.md Decision 7: 排队健康检查与快速失败 + * + * @param {string} apiKeyId - API Key ID + * @param {number} timeoutMs - 排队超时时间(毫秒) + * @param {Object} queueConfig - 队列配置 + * @param {number} maxQueueSize - 最大排队数 + * @returns {Promise} { reject: boolean, reason?: string, estimatedWaitMs?: number, timeoutMs?: number } + */ +async function shouldRejectDueToOverload(apiKeyId, timeoutMs, queueConfig, maxQueueSize) { + try { + // 如果健康检查被禁用,直接返回不拒绝 + if (!queueConfig.concurrentRequestQueueHealthCheckEnabled) { + return { reject: false, reason: 'health_check_disabled' } + } + + // 🔑 先检查当前队列长度 + const currentQueueCount = await redis.getConcurrencyQueueCount(apiKeyId).catch(() => 0) + + // 队列为空,说明系统已恢复,跳过健康检查 + if (currentQueueCount === 0) { + return { reject: false, reason: 'queue_empty', currentQueueCount: 0 } + } + + // 🔑 关键改进:只有当队列接近满载时才进行健康检查 + // 队列长度 <= maxQueueSize * 0.5 时,认为系统有足够余量,跳过健康检查 + // 这避免了在队列较短时过于保守地拒绝请求 + // 使用 ceil 确保小队列(如 maxQueueSize=3)时阈值为 2,即队列 <=1 时跳过 + const queueLoadThreshold = Math.ceil(maxQueueSize * 0.5) + if (currentQueueCount <= queueLoadThreshold) { + return { + reject: false, + reason: 'queue_not_loaded', + currentQueueCount, + queueLoadThreshold, + maxQueueSize + } + } + + // 获取该 API Key 的等待时间样本 + const waitTimes = await redis.getQueueWaitTimes(apiKeyId) + const stats = calculateWaitTimeStats(waitTimes) + + // 样本不足(< 10),跳过健康检查,避免冷启动误判 + if (!stats || stats.sampleCount < 10) { + return { reject: false, reason: 'insufficient_samples', sampleCount: stats?.sampleCount || 0 } + } + + // P90 不可靠时也跳过(虽然 sampleCount >= 10 时 p90Unreliable 应该是 false) + if (stats.p90Unreliable) { + return { reject: false, reason: 'p90_unreliable', sampleCount: stats.sampleCount } + } + + // 计算健康阈值:P90 >= 超时时间 × 阈值 时拒绝 + const threshold = queueConfig.concurrentRequestQueueHealthThreshold || 0.8 + const maxAllowedP90 = timeoutMs * threshold + + if (stats.p90 >= maxAllowedP90) { + return { + reject: true, + reason: 'queue_overloaded', + estimatedWaitMs: stats.p90, + timeoutMs, + threshold, + sampleCount: stats.sampleCount, + currentQueueCount, + maxQueueSize + } + } + + return { reject: false, p90: stats.p90, sampleCount: stats.sampleCount, currentQueueCount } + } catch (error) { + // 健康检查出错时不阻塞请求,记录警告并继续 + logger.warn(`Health check failed for ${apiKeyId}:`, error.message) + return { reject: false, reason: 'health_check_error', error: error.message } + } +} + +// 排队轮询配置常量(可通过配置文件覆盖) +// 性能权衡:初始间隔越短响应越快,但 Redis QPS 越高 +// 当前配置:100 个等待者时约 250-300 QPS(指数退避后) +const QUEUE_POLLING_CONFIG = { + pollIntervalMs: 200, // 初始轮询间隔(毫秒)- 平衡响应速度和 Redis 压力 + maxPollIntervalMs: 2000, // 最大轮询间隔(毫秒)- 长时间等待时降低 Redis 压力 + backoffFactor: 1.5, // 指数退避系数 + jitterRatio: 0.2, // 抖动比例(±20%)- 防止惊群效应 + maxRedisFailCount: 5 // 连续 Redis 失败阈值(从 3 提高到 5,提高网络抖动容忍度) +} const FALLBACK_CONCURRENCY_CONFIG = { leaseSeconds: 300, @@ -128,9 +224,223 @@ function isTokenCountRequest(req) { return false } +/** + * 等待并发槽位(排队机制核心) + * + * 采用「先占后检查」模式避免竞态条件: + * - 每次轮询时尝试 incrConcurrency 占位 + * - 如果超限则 decrConcurrency 释放并继续等待 + * - 成功获取槽位后返回,调用方无需再次 incrConcurrency + * + * ⚠️ 重要清理责任说明: + * - 排队计数:此函数的 finally 块负责调用 decrConcurrencyQueue 清理 + * - 并发槽位:当返回 acquired=true 时,槽位已被占用(通过 incrConcurrency) + * 调用方必须在请求结束时调用 decrConcurrency 释放槽位 + * (已在 authenticateApiKey 的 finally 块中处理) + * + * @param {Object} req - Express 请求对象 + * @param {Object} res - Express 响应对象 + * @param {string} apiKeyId - API Key ID + * @param {Object} queueOptions - 配置参数 + * @returns {Promise} { acquired: boolean, reason?: string, waitTimeMs: number } + */ +async function waitForConcurrencySlot(req, res, apiKeyId, queueOptions) { + const { + concurrencyLimit, + requestId, + leaseSeconds, + timeoutMs, + pollIntervalMs, + maxPollIntervalMs, + backoffFactor, + jitterRatio, + maxRedisFailCount: configMaxRedisFailCount + } = queueOptions + + let clientDisconnected = false + // 追踪轮询过程中是否临时占用了槽位(用于异常时清理) + // 工作流程: + // 1. incrConcurrency 成功且 count <= limit 时,设置 internalSlotAcquired = true + // 2. 统计记录完成后,设置 internalSlotAcquired = false 并返回(所有权转移给调用方) + // 3. 如果在步骤 1-2 之间发生异常,finally 块会检测到 internalSlotAcquired = true 并释放槽位 + let internalSlotAcquired = false + + // 监听客户端断开事件 + // ⚠️ 重要:必须监听 socket 的事件,而不是 req 的事件! + // 原因:对于 POST 请求,当 body-parser 读取完请求体后,req(IncomingMessage 可读流) + // 的 'close' 事件会立即触发,但这不代表客户端断开连接!客户端仍在等待响应。 + // socket 的 'close' 事件才是真正的连接关闭信号。 + const { socket } = req + const onSocketClose = () => { + clientDisconnected = true + logger.debug( + `🔌 [Queue] Socket closed during queue wait for API key ${apiKeyId}, requestId: ${requestId}` + ) + } + + if (socket) { + socket.once('close', onSocketClose) + } + + // 检查 socket 是否在监听器注册前已被销毁(边界情况) + if (socket?.destroyed) { + clientDisconnected = true + } + + const startTime = Date.now() + let pollInterval = pollIntervalMs + let redisFailCount = 0 + // 优先使用配置中的值,否则使用默认值 + const maxRedisFailCount = configMaxRedisFailCount || QUEUE_POLLING_CONFIG.maxRedisFailCount + + try { + while (Date.now() - startTime < timeoutMs) { + // 检测客户端是否断开(双重检查:事件标记 + socket 状态) + // socket.destroyed 是同步检查,确保即使事件处理有延迟也能及时检测 + if (clientDisconnected || socket?.destroyed) { + redis + .incrConcurrencyQueueStats(apiKeyId, 'cancelled') + .catch((e) => logger.warn('Failed to record cancelled stat:', e)) + return { + acquired: false, + reason: 'client_disconnected', + waitTimeMs: Date.now() - startTime + } + } + + // 尝试获取槽位(先占后检查) + try { + const count = await redis.incrConcurrency(apiKeyId, requestId, leaseSeconds) + redisFailCount = 0 // 重置失败计数 + + if (count <= concurrencyLimit) { + // 成功获取槽位! + const waitTimeMs = Date.now() - startTime + + // 槽位所有权转移说明: + // 1. 此时槽位已通过 incrConcurrency 获取 + // 2. 先标记 internalSlotAcquired = true,确保异常时 finally 块能清理 + // 3. 统计操作完成后,清除标记并返回,所有权转移给调用方 + // 4. 调用方(authenticateApiKey)负责在请求结束时释放槽位 + + // 标记槽位已获取(用于异常时 finally 块清理) + internalSlotAcquired = true + + // 记录统计(非阻塞,fire-and-forget 模式) + // ⚠️ 设计说明: + // - 故意不 await 这些 Promise,因为统计记录不应阻塞请求处理 + // - 每个 Promise 都有独立的 .catch(),确保单个失败不影响其他 + // - 外层 .catch() 是防御性措施,处理 Promise.all 本身的异常 + // - 即使统计记录在函数返回后才完成/失败,也是安全的(仅日志记录) + // - 统计数据丢失可接受,不影响核心业务逻辑 + Promise.all([ + redis + .recordQueueWaitTime(apiKeyId, waitTimeMs) + .catch((e) => logger.warn('Failed to record queue wait time:', e)), + redis + .recordGlobalQueueWaitTime(waitTimeMs) + .catch((e) => logger.warn('Failed to record global wait time:', e)), + redis + .incrConcurrencyQueueStats(apiKeyId, 'success') + .catch((e) => logger.warn('Failed to increment success stats:', e)) + ]).catch((e) => logger.warn('Failed to record queue stats batch:', e)) + + // 成功返回前清除标记(所有权转移给调用方,由其负责释放) + internalSlotAcquired = false + return { acquired: true, waitTimeMs } + } + + // 超限,释放槽位继续等待 + try { + await redis.decrConcurrency(apiKeyId, requestId) + } catch (decrError) { + // 释放失败时记录警告但继续轮询 + // 下次 incrConcurrency 会自然覆盖同一 requestId 的条目 + logger.warn( + `Failed to release slot during polling for ${apiKeyId}, will retry:`, + decrError + ) + } + } catch (redisError) { + redisFailCount++ + logger.error( + `Redis error in queue polling (${redisFailCount}/${maxRedisFailCount}):`, + redisError + ) + + if (redisFailCount >= maxRedisFailCount) { + // 连续 Redis 失败,放弃排队 + return { + acquired: false, + reason: 'redis_error', + waitTimeMs: Date.now() - startTime + } + } + } + + // 指数退避等待 + await sleep(pollInterval) + + // 计算下一次轮询间隔(指数退避 + 抖动) + // 1. 先应用指数退避 + let nextInterval = pollInterval * backoffFactor + // 2. 添加抖动防止惊群效应(±jitterRatio 范围内的随机偏移) + // 抖动范围:[-jitterRatio, +jitterRatio],例如 jitterRatio=0.2 时为 ±20% + // 这是预期行为:负抖动可使间隔略微缩短,正抖动可使间隔略微延长 + // 目的是分散多个等待者的轮询时间点,避免同时请求 Redis + const jitter = nextInterval * jitterRatio * (Math.random() * 2 - 1) + nextInterval = nextInterval + jitter + // 3. 确保在合理范围内:最小 1ms,最大 maxPollIntervalMs + // Math.max(1, ...) 保证即使负抖动也不会产生 ≤0 的间隔 + pollInterval = Math.max(1, Math.min(nextInterval, maxPollIntervalMs)) + } + + // 超时 + redis + .incrConcurrencyQueueStats(apiKeyId, 'timeout') + .catch((e) => logger.warn('Failed to record timeout stat:', e)) + return { acquired: false, reason: 'timeout', waitTimeMs: Date.now() - startTime } + } finally { + // 确保清理: + // 1. 减少排队计数(排队计数在调用方已增加,这里负责减少) + try { + await redis.decrConcurrencyQueue(apiKeyId) + } catch (cleanupError) { + // 清理失败记录错误(可能导致计数泄漏,但有 TTL 保护) + logger.error( + `Failed to decrement queue count in finally block for ${apiKeyId}:`, + cleanupError + ) + } + + // 2. 如果内部获取了槽位但未正常返回(异常路径),释放槽位 + if (internalSlotAcquired) { + try { + await redis.decrConcurrency(apiKeyId, requestId) + logger.warn( + `⚠️ Released orphaned concurrency slot in finally block for ${apiKeyId}, requestId: ${requestId}` + ) + } catch (slotCleanupError) { + logger.error( + `Failed to release orphaned concurrency slot for ${apiKeyId}:`, + slotCleanupError + ) + } + } + + // 清理 socket 事件监听器 + if (socket) { + socket.removeListener('close', onSocketClose) + } + } +} + // 🔑 API Key验证中间件(优化版) const authenticateApiKey = async (req, res, next) => { const startTime = Date.now() + let authErrored = false + let concurrencyCleanup = null + let hasConcurrencySlot = false try { // 安全提取API Key,支持多种格式(包括Gemini CLI支持) @@ -265,39 +575,346 @@ const authenticateApiKey = async (req, res, next) => { } const requestId = uuidv4() + // ⚠️ 优化后的 Connection: close 设置策略 + // 问题背景:HTTP Keep-Alive 使多个请求共用同一个 TCP 连接 + // 当第一个请求正在处理,第二个请求进入排队时,它们共用同一个 socket + // 如果客户端超时关闭连接,两个请求都会受影响 + // 优化方案:只有在请求实际进入排队时才设置 Connection: close + // 未排队的请求保持 Keep-Alive,避免不必要的 TCP 握手开销 + // 详见 design.md Decision 2: Connection: close 设置时机 + // 注意:Connection: close 将在下方代码实际进入排队时设置(第 637 行左右) + + // ============================================================ + // 🔒 并发槽位状态管理说明 + // ============================================================ + // 此函数中有两个关键状态变量: + // - hasConcurrencySlot: 当前是否持有并发槽位 + // - concurrencyCleanup: 错误时调用的清理函数 + // + // 状态转换流程: + // 1. incrConcurrency 成功 → hasConcurrencySlot=true, 设置临时清理函数 + // 2. 若超限 → 释放槽位,hasConcurrencySlot=false, concurrencyCleanup=null + // 3. 若排队成功 → hasConcurrencySlot=true, 升级为完整清理函数(含 interval 清理) + // 4. 请求结束(res.close/req.close)→ 调用 decrementConcurrency 释放 + // 5. 认证错误 → finally 块调用 concurrencyCleanup 释放 + // + // 为什么需要两种清理函数? + // - 临时清理:在排队/认证过程中出错时使用,只释放槽位 + // - 完整清理:请求正常开始后使用,还需清理 leaseRenewInterval + // ============================================================ + const setTemporaryConcurrencyCleanup = () => { + concurrencyCleanup = async () => { + if (!hasConcurrencySlot) { + return + } + hasConcurrencySlot = false + try { + await redis.decrConcurrency(validation.keyData.id, requestId) + } catch (cleanupError) { + logger.error( + `Failed to decrement concurrency after auth error for key ${validation.keyData.id}:`, + cleanupError + ) + } + } + } + const currentConcurrency = await redis.incrConcurrency( validation.keyData.id, requestId, leaseSeconds ) + hasConcurrencySlot = true + setTemporaryConcurrencyCleanup() logger.api( `📈 Incremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), current: ${currentConcurrency}, limit: ${concurrencyLimit}` ) if (currentConcurrency > concurrencyLimit) { - // 如果超过限制,立即减少计数(添加 try-catch 防止异常导致并发泄漏) + // 1. 先释放刚占用的槽位 try { - const newCount = await redis.decrConcurrency(validation.keyData.id, requestId) - logger.api( - `📉 Decremented concurrency (429 rejected) for key: ${validation.keyData.id} (${validation.keyData.name}), new count: ${newCount}` - ) + await redis.decrConcurrency(validation.keyData.id, requestId) } catch (error) { logger.error( `Failed to decrement concurrency after limit exceeded for key ${validation.keyData.id}:`, error ) } - logger.security( - `🚦 Concurrency limit exceeded for key: ${validation.keyData.id} (${ - validation.keyData.name - }), current: ${currentConcurrency - 1}, limit: ${concurrencyLimit}` + hasConcurrencySlot = false + concurrencyCleanup = null + + // 2. 获取排队配置 + const queueConfig = await claudeRelayConfigService.getConfig() + + // 3. 排队功能未启用,直接返回 429(保持现有行为) + if (!queueConfig.concurrentRequestQueueEnabled) { + logger.security( + `🚦 Concurrency limit exceeded for key: ${validation.keyData.id} (${ + validation.keyData.name + }), current: ${currentConcurrency - 1}, limit: ${concurrencyLimit}` + ) + // 建议客户端在短暂延迟后重试(并发场景下通常很快会有槽位释放) + res.set('Retry-After', '1') + return res.status(429).json({ + error: 'Concurrency limit exceeded', + message: `Too many concurrent requests. Limit: ${concurrencyLimit} concurrent requests`, + currentConcurrency: currentConcurrency - 1, + concurrencyLimit + }) + } + + // 4. 计算最大排队数 + const maxQueueSize = Math.max( + concurrencyLimit * queueConfig.concurrentRequestQueueMaxSizeMultiplier, + queueConfig.concurrentRequestQueueMaxSize ) - return res.status(429).json({ - error: 'Concurrency limit exceeded', - message: `Too many concurrent requests. Limit: ${concurrencyLimit} concurrent requests`, - currentConcurrency: currentConcurrency - 1, - concurrencyLimit - }) + + // 4.5 排队健康检查:过载时快速失败 + // 详见 design.md Decision 7: 排队健康检查与快速失败 + const overloadCheck = await shouldRejectDueToOverload( + validation.keyData.id, + queueConfig.concurrentRequestQueueTimeoutMs, + queueConfig, + maxQueueSize + ) + if (overloadCheck.reject) { + // 使用健康检查返回的当前排队数,避免重复调用 Redis + const currentQueueCount = overloadCheck.currentQueueCount || 0 + logger.api( + `🚨 Queue overloaded for key: ${validation.keyData.id} (${validation.keyData.name}), ` + + `P90=${overloadCheck.estimatedWaitMs}ms, timeout=${overloadCheck.timeoutMs}ms, ` + + `threshold=${overloadCheck.threshold}, samples=${overloadCheck.sampleCount}, ` + + `concurrency=${concurrencyLimit}, queue=${currentQueueCount}/${maxQueueSize}` + ) + // 记录被拒绝的过载统计 + redis + .incrConcurrencyQueueStats(validation.keyData.id, 'rejected_overload') + .catch((e) => logger.warn('Failed to record rejected_overload stat:', e)) + // 返回 429 + Retry-After,让客户端稍后重试 + const retryAfterSeconds = 30 + res.set('Retry-After', String(retryAfterSeconds)) + return res.status(429).json({ + error: 'Queue overloaded', + message: `Queue is overloaded. Estimated wait time (${overloadCheck.estimatedWaitMs}ms) exceeds threshold. Limit: ${concurrencyLimit} concurrent requests, queue: ${currentQueueCount}/${maxQueueSize}. Please retry later.`, + currentConcurrency: concurrencyLimit, + concurrencyLimit, + queueCount: currentQueueCount, + maxQueueSize, + estimatedWaitMs: overloadCheck.estimatedWaitMs, + timeoutMs: overloadCheck.timeoutMs, + queueTimeoutMs: queueConfig.concurrentRequestQueueTimeoutMs, + retryAfterSeconds + }) + } + + // 5. 尝试进入排队(原子操作:先增加再检查,避免竞态条件) + let queueIncremented = false + try { + const newQueueCount = await redis.incrConcurrencyQueue( + validation.keyData.id, + queueConfig.concurrentRequestQueueTimeoutMs + ) + queueIncremented = true + + if (newQueueCount > maxQueueSize) { + // 超过最大排队数,立即释放并返回 429 + await redis.decrConcurrencyQueue(validation.keyData.id) + queueIncremented = false + logger.api( + `🚦 Concurrency queue full for key: ${validation.keyData.id} (${validation.keyData.name}), ` + + `queue: ${newQueueCount - 1}, maxQueue: ${maxQueueSize}` + ) + // 队列已满,建议客户端在排队超时时间后重试 + const retryAfterSeconds = Math.ceil(queueConfig.concurrentRequestQueueTimeoutMs / 1000) + res.set('Retry-After', String(retryAfterSeconds)) + return res.status(429).json({ + error: 'Concurrency queue full', + message: `Too many requests waiting in queue. Limit: ${concurrencyLimit} concurrent requests, queue: ${newQueueCount - 1}/${maxQueueSize}, timeout: ${retryAfterSeconds}s`, + currentConcurrency: concurrencyLimit, + concurrencyLimit, + queueCount: newQueueCount - 1, + maxQueueSize, + queueTimeoutMs: queueConfig.concurrentRequestQueueTimeoutMs, + retryAfterSeconds + }) + } + + // 6. 已成功进入排队,记录统计并开始等待槽位 + logger.api( + `⏳ Request entering queue for key: ${validation.keyData.id} (${validation.keyData.name}), ` + + `queue position: ${newQueueCount}` + ) + redis + .incrConcurrencyQueueStats(validation.keyData.id, 'entered') + .catch((e) => logger.warn('Failed to record entered stat:', e)) + + // ⚠️ 仅在请求实际进入排队时设置 Connection: close + // 详见 design.md Decision 2: Connection: close 设置时机 + // 未排队的请求保持 Keep-Alive,避免不必要的 TCP 握手开销 + if (!res.headersSent) { + res.setHeader('Connection', 'close') + logger.api( + `🔌 [Queue] Set Connection: close for queued request, key: ${validation.keyData.id}` + ) + } + + // ⚠️ 记录排队开始时的 socket 标识,用于排队完成后验证 + // 问题背景:HTTP Keep-Alive 连接复用时,长时间排队可能导致 socket 被其他请求使用 + // 验证方法:使用 UUID token + socket 对象引用双重验证 + // 详见 design.md Decision 1: Socket 身份验证机制 + req._crService = req._crService || {} + req._crService.queueToken = uuidv4() + req._crService.originalSocket = req.socket + req._crService.startTime = Date.now() + const savedToken = req._crService.queueToken + const savedSocket = req._crService.originalSocket + + // ⚠️ 重要:在调用前将 queueIncremented 设为 false + // 因为 waitForConcurrencySlot 的 finally 块会负责清理排队计数 + // 如果在调用后设置,当 waitForConcurrencySlot 抛出异常时 + // 外层 catch 块会重复减少计数(finally 已经减过一次) + queueIncremented = false + + const slot = await waitForConcurrencySlot(req, res, validation.keyData.id, { + concurrencyLimit, + requestId, + leaseSeconds, + timeoutMs: queueConfig.concurrentRequestQueueTimeoutMs, + pollIntervalMs: QUEUE_POLLING_CONFIG.pollIntervalMs, + maxPollIntervalMs: QUEUE_POLLING_CONFIG.maxPollIntervalMs, + backoffFactor: QUEUE_POLLING_CONFIG.backoffFactor, + jitterRatio: QUEUE_POLLING_CONFIG.jitterRatio, + maxRedisFailCount: queueConfig.concurrentRequestQueueMaxRedisFailCount + }) + + // 7. 处理排队结果 + if (!slot.acquired) { + if (slot.reason === 'client_disconnected') { + // 客户端已断开,不返回响应(连接已关闭) + logger.api( + `🔌 Client disconnected while queuing for key: ${validation.keyData.id} (${validation.keyData.name})` + ) + return + } + + if (slot.reason === 'redis_error') { + // Redis 连续失败,返回 503 + logger.error( + `❌ Redis error during queue wait for key: ${validation.keyData.id} (${validation.keyData.name})` + ) + return res.status(503).json({ + error: 'Service temporarily unavailable', + message: 'Failed to acquire concurrency slot due to internal error' + }) + } + // 排队超时(使用 api 级别,与其他排队日志保持一致) + logger.api( + `⏰ Queue timeout for key: ${validation.keyData.id} (${validation.keyData.name}), waited: ${slot.waitTimeMs}ms` + ) + // 已等待超时,建议客户端稍后重试 + // ⚠️ Retry-After 策略优化: + // - 请求已经等了完整的 timeout 时间,说明系统负载较高 + // - 过早重试(如固定 5 秒)会加剧拥塞,导致更多超时 + // - 合理策略:使用 timeout 时间的一半作为重试间隔 + // - 最小值 5 秒,最大值 30 秒,避免极端情况 + const timeoutSeconds = Math.ceil(queueConfig.concurrentRequestQueueTimeoutMs / 1000) + const retryAfterSeconds = Math.max(5, Math.min(30, Math.ceil(timeoutSeconds / 2))) + res.set('Retry-After', String(retryAfterSeconds)) + return res.status(429).json({ + error: 'Queue timeout', + message: `Request timed out waiting for concurrency slot. Limit: ${concurrencyLimit} concurrent requests, maxQueue: ${maxQueueSize}, Queue timeout: ${timeoutSeconds}s, waited: ${slot.waitTimeMs}ms`, + currentConcurrency: concurrencyLimit, + concurrencyLimit, + maxQueueSize, + queueTimeoutMs: queueConfig.concurrentRequestQueueTimeoutMs, + waitTimeMs: slot.waitTimeMs, + retryAfterSeconds + }) + } + + // 8. 排队成功,slot.acquired 表示已在 waitForConcurrencySlot 中获取到槽位 + logger.api( + `✅ Queue wait completed for key: ${validation.keyData.id} (${validation.keyData.name}), ` + + `waited: ${slot.waitTimeMs}ms` + ) + hasConcurrencySlot = true + setTemporaryConcurrencyCleanup() + + // 9. ⚠️ 关键检查:排队等待结束后,验证客户端是否还在等待响应 + // 长时间排队后,客户端可能在应用层已放弃(如 Claude Code 的超时机制), + // 但 TCP 连接仍然存活。此时继续处理请求是浪费资源。 + // 注意:如果发送了心跳,headersSent 会是 true,但这是正常的 + const postQueueSocket = req.socket + // 只检查连接是否真正断开(destroyed/writableEnded/socketDestroyed) + // headersSent 在心跳场景下是正常的,不应该作为放弃的依据 + if (res.destroyed || res.writableEnded || postQueueSocket?.destroyed) { + logger.warn( + `⚠️ Client no longer waiting after queue for key: ${validation.keyData.id} (${validation.keyData.name}), ` + + `waited: ${slot.waitTimeMs}ms | destroyed: ${res.destroyed}, ` + + `writableEnded: ${res.writableEnded}, socketDestroyed: ${postQueueSocket?.destroyed}` + ) + // 释放刚获取的槽位 + hasConcurrencySlot = false + await redis + .decrConcurrency(validation.keyData.id, requestId) + .catch((e) => logger.error('Failed to release slot after client abandoned:', e)) + // 不返回响应(客户端已不在等待) + return + } + + // 10. ⚠️ 关键检查:验证 socket 身份是否改变 + // HTTP Keep-Alive 连接复用可能导致排队期间 socket 被其他请求使用 + // 验证方法:UUID token + socket 对象引用双重验证 + // 详见 design.md Decision 1: Socket 身份验证机制 + const queueData = req._crService + const socketIdentityChanged = + !queueData || + queueData.queueToken !== savedToken || + queueData.originalSocket !== savedSocket + + if (socketIdentityChanged) { + logger.error( + `❌ [Queue] Socket identity changed during queue wait! ` + + `key: ${validation.keyData.id} (${validation.keyData.name}), ` + + `waited: ${slot.waitTimeMs}ms | ` + + `tokenMatch: ${queueData?.queueToken === savedToken}, ` + + `socketMatch: ${queueData?.originalSocket === savedSocket}` + ) + // 释放刚获取的槽位 + hasConcurrencySlot = false + await redis + .decrConcurrency(validation.keyData.id, requestId) + .catch((e) => logger.error('Failed to release slot after socket identity change:', e)) + // 记录 socket_changed 统计 + redis + .incrConcurrencyQueueStats(validation.keyData.id, 'socket_changed') + .catch((e) => logger.warn('Failed to record socket_changed stat:', e)) + // 不返回响应(socket 已被其他请求使用) + return + } + } catch (queueError) { + // 异常时清理资源,防止泄漏 + // 1. 清理排队计数(如果还没被 waitForConcurrencySlot 的 finally 清理) + if (queueIncremented) { + await redis + .decrConcurrencyQueue(validation.keyData.id) + .catch((e) => logger.error('Failed to cleanup queue count after error:', e)) + } + + // 2. 防御性清理:如果 waitForConcurrencySlot 内部获取了槽位但在返回前异常 + // 虽然这种情况极少发生(统计记录的异常会被内部捕获),但为了安全起见 + // 尝试释放可能已获取的槽位。decrConcurrency 使用 ZREM,即使成员不存在也安全 + if (hasConcurrencySlot) { + hasConcurrencySlot = false + await redis + .decrConcurrency(validation.keyData.id, requestId) + .catch((e) => + logger.error('Failed to cleanup concurrency slot after queue error:', e) + ) + } + + throw queueError + } } const renewIntervalMs = @@ -358,6 +975,7 @@ const authenticateApiKey = async (req, res, next) => { const decrementConcurrency = async () => { if (!concurrencyDecremented) { concurrencyDecremented = true + hasConcurrencySlot = false if (leaseRenewInterval) { clearInterval(leaseRenewInterval) leaseRenewInterval = null @@ -372,6 +990,11 @@ const authenticateApiKey = async (req, res, next) => { } } } + // 升级为完整清理函数(包含 leaseRenewInterval 清理逻辑) + // 此时请求已通过认证,后续由 res.close/req.close 事件触发清理 + if (hasConcurrencySlot) { + concurrencyCleanup = decrementConcurrency + } // 监听最可靠的事件(避免重复监听) // res.on('close') 是最可靠的,会在连接关闭时触发 @@ -697,6 +1320,7 @@ const authenticateApiKey = async (req, res, next) => { return next() } catch (error) { + authErrored = true const authDuration = Date.now() - startTime logger.error(`❌ Authentication middleware error (${authDuration}ms):`, { error: error.message, @@ -710,6 +1334,14 @@ const authenticateApiKey = async (req, res, next) => { error: 'Authentication error', message: 'Internal server error during authentication' }) + } finally { + if (authErrored && typeof concurrencyCleanup === 'function') { + try { + await concurrencyCleanup() + } catch (cleanupError) { + logger.error('Failed to cleanup concurrency after auth error:', cleanupError) + } + } } } @@ -1399,7 +2031,8 @@ const globalRateLimit = async (req, res, next) => // 📊 请求大小限制中间件 const requestSizeLimit = (req, res, next) => { - const maxSize = 60 * 1024 * 1024 // 60MB + const MAX_SIZE_MB = parseInt(process.env.REQUEST_MAX_SIZE_MB || '60', 10) + const maxSize = MAX_SIZE_MB * 1024 * 1024 const contentLength = parseInt(req.headers['content-length'] || '0') if (contentLength > maxSize) { diff --git a/src/models/redis.js b/src/models/redis.js index 2393f3b3..6cffa6a9 100644 --- a/src/models/redis.js +++ b/src/models/redis.js @@ -50,6 +50,18 @@ function getWeekStringInTimezone(date = new Date()) { return `${year}-W${String(weekNumber).padStart(2, '0')}` } +// 并发队列相关常量 +const QUEUE_STATS_TTL_SECONDS = 86400 * 7 // 统计计数保留 7 天 +const WAIT_TIME_TTL_SECONDS = 86400 // 等待时间样本保留 1 天(滚动窗口,无需长期保留) +// 等待时间样本数配置(提高统计置信度) +// - 每 API Key 从 100 提高到 500:提供更稳定的 P99 估计 +// - 全局从 500 提高到 2000:支持更高精度的 P99.9 分析 +// - 内存开销约 12-20KB(Redis quicklist 每元素 1-10 字节),可接受 +// 详见 design.md Decision 5: 等待时间统计样本数 +const WAIT_TIME_SAMPLES_PER_KEY = 500 // 每个 API Key 保留的等待时间样本数 +const WAIT_TIME_SAMPLES_GLOBAL = 2000 // 全局保留的等待时间样本数 +const QUEUE_TTL_BUFFER_SECONDS = 30 // 排队计数器TTL缓冲时间 + class RedisClient { constructor() { this.client = null @@ -84,7 +96,25 @@ class RedisClient { logger.warn('⚠️ Redis connection closed') }) - await this.client.connect() + // 只有在 lazyConnect 模式下才需要手动调用 connect() + // 如果 Redis 已经连接或正在连接中,则跳过 + if ( + this.client.status !== 'connecting' && + this.client.status !== 'connect' && + this.client.status !== 'ready' + ) { + await this.client.connect() + } else { + // 等待 ready 状态 + await new Promise((resolve, reject) => { + if (this.client.status === 'ready') { + resolve() + } else { + this.client.once('ready', resolve) + this.client.once('error', reject) + } + }) + } return this.client } catch (error) { logger.error('💥 Failed to connect to Redis:', error) @@ -2556,4 +2586,838 @@ redisClient.getDateStringInTimezone = getDateStringInTimezone redisClient.getHourInTimezone = getHourInTimezone redisClient.getWeekStringInTimezone = getWeekStringInTimezone +// ============== 用户消息队列相关方法 ============== + +/** + * 尝试获取用户消息队列锁 + * 使用 Lua 脚本保证原子性 + * @param {string} accountId - 账户ID + * @param {string} requestId - 请求ID + * @param {number} lockTtlMs - 锁 TTL(毫秒) + * @param {number} delayMs - 请求间隔(毫秒) + * @returns {Promise<{acquired: boolean, waitMs: number}>} + * - acquired: 是否成功获取锁 + * - waitMs: 需要等待的毫秒数(-1表示被占用需等待,>=0表示需要延迟的毫秒数) + */ +redisClient.acquireUserMessageLock = async function (accountId, requestId, lockTtlMs, delayMs) { + const lockKey = `user_msg_queue_lock:${accountId}` + const lastTimeKey = `user_msg_queue_last:${accountId}` + + const script = ` + local lockKey = KEYS[1] + local lastTimeKey = KEYS[2] + local requestId = ARGV[1] + local lockTtl = tonumber(ARGV[2]) + local delayMs = tonumber(ARGV[3]) + + -- 检查锁是否空闲 + local currentLock = redis.call('GET', lockKey) + if currentLock == false then + -- 检查是否需要延迟 + local lastTime = redis.call('GET', lastTimeKey) + local now = redis.call('TIME') + local nowMs = tonumber(now[1]) * 1000 + math.floor(tonumber(now[2]) / 1000) + + if lastTime then + local elapsed = nowMs - tonumber(lastTime) + if elapsed < delayMs then + -- 需要等待的毫秒数 + return {0, delayMs - elapsed} + end + end + + -- 获取锁 + redis.call('SET', lockKey, requestId, 'PX', lockTtl) + return {1, 0} + end + + -- 锁被占用,返回等待 + return {0, -1} + ` + + try { + const result = await this.client.eval( + script, + 2, + lockKey, + lastTimeKey, + requestId, + lockTtlMs, + delayMs + ) + return { + acquired: result[0] === 1, + waitMs: result[1] + } + } catch (error) { + logger.error(`Failed to acquire user message lock for account ${accountId}:`, error) + // 返回 redisError 标记,让上层能区分 Redis 故障和正常锁占用 + return { acquired: false, waitMs: -1, redisError: true, errorMessage: error.message } + } +} + +/** + * 释放用户消息队列锁并记录完成时间 + * @param {string} accountId - 账户ID + * @param {string} requestId - 请求ID + * @returns {Promise} 是否成功释放 + */ +redisClient.releaseUserMessageLock = async function (accountId, requestId) { + const lockKey = `user_msg_queue_lock:${accountId}` + const lastTimeKey = `user_msg_queue_last:${accountId}` + + const script = ` + local lockKey = KEYS[1] + local lastTimeKey = KEYS[2] + local requestId = ARGV[1] + + -- 验证锁持有者 + local currentLock = redis.call('GET', lockKey) + if currentLock == requestId then + -- 记录完成时间 + local now = redis.call('TIME') + local nowMs = tonumber(now[1]) * 1000 + math.floor(tonumber(now[2]) / 1000) + redis.call('SET', lastTimeKey, nowMs, 'EX', 60) -- 60秒后过期 + + -- 删除锁 + redis.call('DEL', lockKey) + return 1 + end + return 0 + ` + + try { + const result = await this.client.eval(script, 2, lockKey, lastTimeKey, requestId) + return result === 1 + } catch (error) { + logger.error(`Failed to release user message lock for account ${accountId}:`, error) + return false + } +} + +/** + * 强制释放用户消息队列锁(用于清理孤儿锁) + * @param {string} accountId - 账户ID + * @returns {Promise} 是否成功释放 + */ +redisClient.forceReleaseUserMessageLock = async function (accountId) { + const lockKey = `user_msg_queue_lock:${accountId}` + + try { + await this.client.del(lockKey) + return true + } catch (error) { + logger.error(`Failed to force release user message lock for account ${accountId}:`, error) + return false + } +} + +/** + * 获取用户消息队列统计信息(用于调试) + * @param {string} accountId - 账户ID + * @returns {Promise} 队列统计 + */ +redisClient.getUserMessageQueueStats = async function (accountId) { + const lockKey = `user_msg_queue_lock:${accountId}` + const lastTimeKey = `user_msg_queue_last:${accountId}` + + try { + const [lockHolder, lastTime, lockTtl] = await Promise.all([ + this.client.get(lockKey), + this.client.get(lastTimeKey), + this.client.pttl(lockKey) + ]) + + return { + accountId, + isLocked: !!lockHolder, + lockHolder, + lockTtlMs: lockTtl > 0 ? lockTtl : 0, + lockTtlRaw: lockTtl, // 原始 PTTL 值:>0 有TTL,-1 无过期时间,-2 键不存在 + lastCompletedAt: lastTime ? new Date(parseInt(lastTime)).toISOString() : null + } + } catch (error) { + logger.error(`Failed to get user message queue stats for account ${accountId}:`, error) + return { + accountId, + isLocked: false, + lockHolder: null, + lockTtlMs: 0, + lockTtlRaw: -2, + lastCompletedAt: null + } + } +} + +/** + * 扫描所有用户消息队列锁(用于清理任务) + * @returns {Promise} 账户ID列表 + */ +redisClient.scanUserMessageQueueLocks = async function () { + const accountIds = [] + let cursor = '0' + let iterations = 0 + const MAX_ITERATIONS = 1000 // 防止无限循环 + + try { + do { + const [newCursor, keys] = await this.client.scan( + cursor, + 'MATCH', + 'user_msg_queue_lock:*', + 'COUNT', + 100 + ) + cursor = newCursor + iterations++ + + for (const key of keys) { + const accountId = key.replace('user_msg_queue_lock:', '') + accountIds.push(accountId) + } + + // 防止无限循环 + if (iterations >= MAX_ITERATIONS) { + logger.warn( + `📬 User message queue: SCAN reached max iterations (${MAX_ITERATIONS}), stopping early`, + { foundLocks: accountIds.length } + ) + break + } + } while (cursor !== '0') + + if (accountIds.length > 0) { + logger.debug( + `📬 User message queue: scanned ${accountIds.length} lock(s) in ${iterations} iteration(s)` + ) + } + + return accountIds + } catch (error) { + logger.error('Failed to scan user message queue locks:', error) + return [] + } +} + +// ============================================ +// 🚦 API Key 并发请求排队方法 +// ============================================ + +/** + * 增加排队计数(使用 Lua 脚本确保原子性) + * @param {string} apiKeyId - API Key ID + * @param {number} [timeoutMs=60000] - 排队超时时间(毫秒),用于计算 TTL + * @returns {Promise} 增加后的排队数量 + */ +redisClient.incrConcurrencyQueue = async function (apiKeyId, timeoutMs = 60000) { + const key = `concurrency:queue:${apiKeyId}` + try { + // 使用 Lua 脚本确保 INCR 和 EXPIRE 原子执行,防止进程崩溃导致计数器泄漏 + // TTL = 超时时间 + 缓冲时间(确保键不会在请求还在等待时过期) + const ttlSeconds = Math.ceil(timeoutMs / 1000) + QUEUE_TTL_BUFFER_SECONDS + const script = ` + local count = redis.call('INCR', KEYS[1]) + redis.call('EXPIRE', KEYS[1], ARGV[1]) + return count + ` + const count = await this.client.eval(script, 1, key, String(ttlSeconds)) + logger.database( + `🚦 Incremented queue count for key ${apiKeyId}: ${count} (TTL: ${ttlSeconds}s)` + ) + return parseInt(count) + } catch (error) { + logger.error(`Failed to increment concurrency queue for ${apiKeyId}:`, error) + throw error + } +} + +/** + * 减少排队计数(使用 Lua 脚本确保原子性) + * @param {string} apiKeyId - API Key ID + * @returns {Promise} 减少后的排队数量 + */ +redisClient.decrConcurrencyQueue = async function (apiKeyId) { + const key = `concurrency:queue:${apiKeyId}` + try { + // 使用 Lua 脚本确保 DECR 和 DEL 原子执行,防止进程崩溃导致计数器残留 + const script = ` + local count = redis.call('DECR', KEYS[1]) + if count <= 0 then + redis.call('DEL', KEYS[1]) + return 0 + end + return count + ` + const count = await this.client.eval(script, 1, key) + const result = parseInt(count) + if (result === 0) { + logger.database(`🚦 Queue count for key ${apiKeyId} is 0, removed key`) + } else { + logger.database(`🚦 Decremented queue count for key ${apiKeyId}: ${result}`) + } + return result + } catch (error) { + logger.error(`Failed to decrement concurrency queue for ${apiKeyId}:`, error) + throw error + } +} + +/** + * 获取排队计数 + * @param {string} apiKeyId - API Key ID + * @returns {Promise} 当前排队数量 + */ +redisClient.getConcurrencyQueueCount = async function (apiKeyId) { + const key = `concurrency:queue:${apiKeyId}` + try { + const count = await this.client.get(key) + return parseInt(count || 0) + } catch (error) { + logger.error(`Failed to get concurrency queue count for ${apiKeyId}:`, error) + return 0 + } +} + +/** + * 清空排队计数 + * @param {string} apiKeyId - API Key ID + * @returns {Promise} 是否成功清空 + */ +redisClient.clearConcurrencyQueue = async function (apiKeyId) { + const key = `concurrency:queue:${apiKeyId}` + try { + await this.client.del(key) + logger.database(`🚦 Cleared queue count for key ${apiKeyId}`) + return true + } catch (error) { + logger.error(`Failed to clear concurrency queue for ${apiKeyId}:`, error) + return false + } +} + +/** + * 扫描所有排队计数器 + * @returns {Promise} API Key ID 列表 + */ +redisClient.scanConcurrencyQueueKeys = async function () { + const apiKeyIds = [] + let cursor = '0' + let iterations = 0 + const MAX_ITERATIONS = 1000 + + try { + do { + const [newCursor, keys] = await this.client.scan( + cursor, + 'MATCH', + 'concurrency:queue:*', + 'COUNT', + 100 + ) + cursor = newCursor + iterations++ + + for (const key of keys) { + // 排除统计和等待时间相关的键 + if ( + key.startsWith('concurrency:queue:stats:') || + key.startsWith('concurrency:queue:wait_times:') + ) { + continue + } + const apiKeyId = key.replace('concurrency:queue:', '') + apiKeyIds.push(apiKeyId) + } + + if (iterations >= MAX_ITERATIONS) { + logger.warn( + `🚦 Concurrency queue: SCAN reached max iterations (${MAX_ITERATIONS}), stopping early`, + { foundQueues: apiKeyIds.length } + ) + break + } + } while (cursor !== '0') + + return apiKeyIds + } catch (error) { + logger.error('Failed to scan concurrency queue keys:', error) + return [] + } +} + +/** + * 清理所有排队计数器(用于服务重启) + * @returns {Promise} 清理的计数器数量 + */ +redisClient.clearAllConcurrencyQueues = async function () { + let cleared = 0 + let cursor = '0' + let iterations = 0 + const MAX_ITERATIONS = 1000 + + try { + do { + const [newCursor, keys] = await this.client.scan( + cursor, + 'MATCH', + 'concurrency:queue:*', + 'COUNT', + 100 + ) + cursor = newCursor + iterations++ + + // 只删除排队计数器,保留统计数据 + const queueKeys = keys.filter( + (key) => + !key.startsWith('concurrency:queue:stats:') && + !key.startsWith('concurrency:queue:wait_times:') + ) + + if (queueKeys.length > 0) { + await this.client.del(...queueKeys) + cleared += queueKeys.length + } + + if (iterations >= MAX_ITERATIONS) { + break + } + } while (cursor !== '0') + + if (cleared > 0) { + logger.info(`🚦 Cleared ${cleared} concurrency queue counter(s) on startup`) + } + return cleared + } catch (error) { + logger.error('Failed to clear all concurrency queues:', error) + return 0 + } +} + +/** + * 增加排队统计计数(使用 Lua 脚本确保原子性) + * @param {string} apiKeyId - API Key ID + * @param {string} field - 统计字段 (entered/success/timeout/cancelled) + * @returns {Promise} 增加后的计数 + */ +redisClient.incrConcurrencyQueueStats = async function (apiKeyId, field) { + const key = `concurrency:queue:stats:${apiKeyId}` + try { + // 使用 Lua 脚本确保 HINCRBY 和 EXPIRE 原子执行 + // 防止在两者之间崩溃导致统计键没有 TTL(内存泄漏) + const script = ` + local count = redis.call('HINCRBY', KEYS[1], ARGV[1], 1) + redis.call('EXPIRE', KEYS[1], ARGV[2]) + return count + ` + const count = await this.client.eval(script, 1, key, field, String(QUEUE_STATS_TTL_SECONDS)) + return parseInt(count) + } catch (error) { + logger.error(`Failed to increment queue stats ${field} for ${apiKeyId}:`, error) + return 0 + } +} + +/** + * 获取排队统计 + * @param {string} apiKeyId - API Key ID + * @returns {Promise} 统计数据 + */ +redisClient.getConcurrencyQueueStats = async function (apiKeyId) { + const key = `concurrency:queue:stats:${apiKeyId}` + try { + const stats = await this.client.hgetall(key) + return { + entered: parseInt(stats?.entered || 0), + success: parseInt(stats?.success || 0), + timeout: parseInt(stats?.timeout || 0), + cancelled: parseInt(stats?.cancelled || 0), + socket_changed: parseInt(stats?.socket_changed || 0), + rejected_overload: parseInt(stats?.rejected_overload || 0) + } + } catch (error) { + logger.error(`Failed to get queue stats for ${apiKeyId}:`, error) + return { + entered: 0, + success: 0, + timeout: 0, + cancelled: 0, + socket_changed: 0, + rejected_overload: 0 + } + } +} + +/** + * 记录排队等待时间(按 API Key 分开存储) + * @param {string} apiKeyId - API Key ID + * @param {number} waitTimeMs - 等待时间(毫秒) + * @returns {Promise} + */ +redisClient.recordQueueWaitTime = async function (apiKeyId, waitTimeMs) { + const key = `concurrency:queue:wait_times:${apiKeyId}` + try { + // 使用 Lua 脚本确保原子性,同时设置 TTL 防止内存泄漏 + const script = ` + redis.call('LPUSH', KEYS[1], ARGV[1]) + redis.call('LTRIM', KEYS[1], 0, ARGV[2]) + redis.call('EXPIRE', KEYS[1], ARGV[3]) + return 1 + ` + await this.client.eval( + script, + 1, + key, + waitTimeMs, + WAIT_TIME_SAMPLES_PER_KEY - 1, + WAIT_TIME_TTL_SECONDS + ) + } catch (error) { + logger.error(`Failed to record queue wait time for ${apiKeyId}:`, error) + } +} + +/** + * 记录全局排队等待时间 + * @param {number} waitTimeMs - 等待时间(毫秒) + * @returns {Promise} + */ +redisClient.recordGlobalQueueWaitTime = async function (waitTimeMs) { + const key = 'concurrency:queue:wait_times:global' + try { + // 使用 Lua 脚本确保原子性,同时设置 TTL 防止内存泄漏 + const script = ` + redis.call('LPUSH', KEYS[1], ARGV[1]) + redis.call('LTRIM', KEYS[1], 0, ARGV[2]) + redis.call('EXPIRE', KEYS[1], ARGV[3]) + return 1 + ` + await this.client.eval( + script, + 1, + key, + waitTimeMs, + WAIT_TIME_SAMPLES_GLOBAL - 1, + WAIT_TIME_TTL_SECONDS + ) + } catch (error) { + logger.error('Failed to record global queue wait time:', error) + } +} + +/** + * 获取全局等待时间列表 + * @returns {Promise} 等待时间列表 + */ +redisClient.getGlobalQueueWaitTimes = async function () { + const key = 'concurrency:queue:wait_times:global' + try { + const samples = await this.client.lrange(key, 0, -1) + return samples.map(Number) + } catch (error) { + logger.error('Failed to get global queue wait times:', error) + return [] + } +} + +/** + * 获取指定 API Key 的等待时间列表 + * @param {string} apiKeyId - API Key ID + * @returns {Promise} 等待时间列表 + */ +redisClient.getQueueWaitTimes = async function (apiKeyId) { + const key = `concurrency:queue:wait_times:${apiKeyId}` + try { + const samples = await this.client.lrange(key, 0, -1) + return samples.map(Number) + } catch (error) { + logger.error(`Failed to get queue wait times for ${apiKeyId}:`, error) + return [] + } +} + +/** + * 扫描所有排队统计键 + * @returns {Promise} API Key ID 列表 + */ +redisClient.scanConcurrencyQueueStatsKeys = async function () { + const apiKeyIds = [] + let cursor = '0' + let iterations = 0 + const MAX_ITERATIONS = 1000 + + try { + do { + const [newCursor, keys] = await this.client.scan( + cursor, + 'MATCH', + 'concurrency:queue:stats:*', + 'COUNT', + 100 + ) + cursor = newCursor + iterations++ + + for (const key of keys) { + const apiKeyId = key.replace('concurrency:queue:stats:', '') + apiKeyIds.push(apiKeyId) + } + + if (iterations >= MAX_ITERATIONS) { + break + } + } while (cursor !== '0') + + return apiKeyIds + } catch (error) { + logger.error('Failed to scan concurrency queue stats keys:', error) + return [] + } +} + +// ============================================================================ +// 账户测试历史相关操作 +// ============================================================================ + +const ACCOUNT_TEST_HISTORY_MAX = 5 // 保留最近5次测试记录 +const ACCOUNT_TEST_HISTORY_TTL = 86400 * 30 // 30天过期 +const ACCOUNT_TEST_CONFIG_TTL = 86400 * 365 // 测试配置保留1年(用户通常长期使用) + +/** + * 保存账户测试结果 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 (claude/gemini/openai等) + * @param {Object} testResult - 测试结果对象 + * @param {boolean} testResult.success - 是否成功 + * @param {string} testResult.message - 测试消息/响应 + * @param {number} testResult.latencyMs - 延迟毫秒数 + * @param {string} testResult.error - 错误信息(如有) + * @param {string} testResult.timestamp - 测试时间戳 + */ +redisClient.saveAccountTestResult = async function (accountId, platform, testResult) { + const key = `account:test_history:${platform}:${accountId}` + try { + const record = JSON.stringify({ + ...testResult, + timestamp: testResult.timestamp || new Date().toISOString() + }) + + // 使用 LPUSH + LTRIM 保持最近5条记录 + const client = this.getClientSafe() + await client.lpush(key, record) + await client.ltrim(key, 0, ACCOUNT_TEST_HISTORY_MAX - 1) + await client.expire(key, ACCOUNT_TEST_HISTORY_TTL) + + logger.debug(`📝 Saved test result for ${platform} account ${accountId}`) + } catch (error) { + logger.error(`Failed to save test result for ${accountId}:`, error) + } +} + +/** + * 获取账户测试历史 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 测试历史记录数组(最新在前) + */ +redisClient.getAccountTestHistory = async function (accountId, platform) { + const key = `account:test_history:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const records = await client.lrange(key, 0, -1) + return records.map((r) => JSON.parse(r)) + } catch (error) { + logger.error(`Failed to get test history for ${accountId}:`, error) + return [] + } +} + +/** + * 获取账户最新测试结果 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 最新测试结果 + */ +redisClient.getAccountLatestTestResult = async function (accountId, platform) { + const key = `account:test_history:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const record = await client.lindex(key, 0) + return record ? JSON.parse(record) : null + } catch (error) { + logger.error(`Failed to get latest test result for ${accountId}:`, error) + return null + } +} + +/** + * 批量获取多个账户的测试历史 + * @param {Array<{accountId: string, platform: string}>} accounts - 账户列表 + * @returns {Promise} 以 accountId 为 key 的测试历史映射 + */ +redisClient.getAccountsTestHistory = async function (accounts) { + const result = {} + try { + const client = this.getClientSafe() + const pipeline = client.pipeline() + + for (const { accountId, platform } of accounts) { + const key = `account:test_history:${platform}:${accountId}` + pipeline.lrange(key, 0, -1) + } + + const responses = await pipeline.exec() + + accounts.forEach(({ accountId }, index) => { + const [err, records] = responses[index] + if (!err && records) { + result[accountId] = records.map((r) => JSON.parse(r)) + } else { + result[accountId] = [] + } + }) + } catch (error) { + logger.error('Failed to get batch test history:', error) + } + return result +} + +/** + * 保存定时测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {Object} config - 配置对象 + * @param {boolean} config.enabled - 是否启用定时测试 + * @param {string} config.cronExpression - Cron 表达式 (如 "0 8 * * *" 表示每天8点) + * @param {string} config.model - 测试使用的模型 + */ +redisClient.saveAccountTestConfig = async function (accountId, platform, testConfig) { + const key = `account:test_config:${platform}:${accountId}` + try { + const client = this.getClientSafe() + await client.hset(key, { + enabled: testConfig.enabled ? 'true' : 'false', + cronExpression: testConfig.cronExpression || '0 8 * * *', // 默认每天早上8点 + model: testConfig.model || 'claude-sonnet-4-5-20250929', // 默认模型 + updatedAt: new Date().toISOString() + }) + // 设置过期时间(1年) + await client.expire(key, ACCOUNT_TEST_CONFIG_TTL) + } catch (error) { + logger.error(`Failed to save test config for ${accountId}:`, error) + } +} + +/** + * 获取定时测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 配置对象 + */ +redisClient.getAccountTestConfig = async function (accountId, platform) { + const key = `account:test_config:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const testConfig = await client.hgetall(key) + if (!testConfig || Object.keys(testConfig).length === 0) { + return null + } + // 向后兼容:如果存在旧的 testHour 字段,转换为 cron 表达式 + let { cronExpression } = testConfig + if (!cronExpression && testConfig.testHour) { + const hour = parseInt(testConfig.testHour, 10) + cronExpression = `0 ${hour} * * *` + } + return { + enabled: testConfig.enabled === 'true', + cronExpression: cronExpression || '0 8 * * *', + model: testConfig.model || 'claude-sonnet-4-5-20250929', + updatedAt: testConfig.updatedAt + } + } catch (error) { + logger.error(`Failed to get test config for ${accountId}:`, error) + return null + } +} + +/** + * 获取所有启用定时测试的账户 + * @param {string} platform - 平台类型 + * @returns {Promise} 账户ID列表及 cron 配置 + */ +redisClient.getEnabledTestAccounts = async function (platform) { + const accountIds = [] + let cursor = '0' + + try { + const client = this.getClientSafe() + do { + const [newCursor, keys] = await client.scan( + cursor, + 'MATCH', + `account:test_config:${platform}:*`, + 'COUNT', + 100 + ) + cursor = newCursor + + for (const key of keys) { + const testConfig = await client.hgetall(key) + if (testConfig && testConfig.enabled === 'true') { + const accountId = key.replace(`account:test_config:${platform}:`, '') + // 向后兼容:如果存在旧的 testHour 字段,转换为 cron 表达式 + let { cronExpression } = testConfig + if (!cronExpression && testConfig.testHour) { + const hour = parseInt(testConfig.testHour, 10) + cronExpression = `0 ${hour} * * *` + } + accountIds.push({ + accountId, + cronExpression: cronExpression || '0 8 * * *', + model: testConfig.model || 'claude-sonnet-4-5-20250929' + }) + } + } + } while (cursor !== '0') + + return accountIds + } catch (error) { + logger.error(`Failed to get enabled test accounts for ${platform}:`, error) + return [] + } +} + +/** + * 保存账户上次测试时间(用于调度器判断是否需要测试) + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + */ +redisClient.setAccountLastTestTime = async function (accountId, platform) { + const key = `account:last_test:${platform}:${accountId}` + try { + const client = this.getClientSafe() + await client.set(key, Date.now().toString(), 'EX', 86400 * 7) // 7天过期 + } catch (error) { + logger.error(`Failed to set last test time for ${accountId}:`, error) + } +} + +/** + * 获取账户上次测试时间 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 上次测试时间戳 + */ +redisClient.getAccountLastTestTime = async function (accountId, platform) { + const key = `account:last_test:${platform}:${accountId}` + try { + const client = this.getClientSafe() + const timestamp = await client.get(key) + return timestamp ? parseInt(timestamp, 10) : null + } catch (error) { + logger.error(`Failed to get last test time for ${accountId}:`, error) + return null + } +} + module.exports = redisClient diff --git a/src/routes/admin/apiKeys.js b/src/routes/admin/apiKeys.js index 8e444067..d88444bd 100644 --- a/src/routes/admin/apiKeys.js +++ b/src/routes/admin/apiKeys.js @@ -945,6 +945,30 @@ async function calculateKeyStats(keyId, timeRange, startDate, endDate) { allTimeCost = parseFloat((await client.get(totalCostKey)) || '0') } + // 🔧 FIX: 对于 "全部时间" 时间范围,直接使用 allTimeCost + // 因为 usage:*:model:daily:* 键有 30 天 TTL,旧数据已经过期 + if (timeRange === 'all' && allTimeCost > 0) { + logger.debug(`📊 使用 allTimeCost 计算 timeRange='all': ${allTimeCost}`) + + return { + requests: 0, // 旧数据详情不可用 + tokens: 0, + inputTokens: 0, + outputTokens: 0, + cacheCreateTokens: 0, + cacheReadTokens: 0, + cost: allTimeCost, + formattedCost: CostCalculator.formatCost(allTimeCost), + // 实时限制数据(始终返回,不受时间范围影响) + dailyCost, + currentWindowCost, + windowRemainingSeconds, + windowStartTime, + windowEndTime, + allTimeCost + } + } + // 只在启用了窗口限制时查询窗口数据 if (rateLimitWindow > 0) { const costCountKey = `rate_limit:cost:${keyId}` @@ -1006,12 +1030,10 @@ async function calculateKeyStats(keyId, timeRange, startDate, endDate) { const modelStatsMap = new Map() let totalRequests = 0 - // 用于去重:只统计日数据,避免与月数据重复 + // 用于去重:先统计月数据,避免与日数据重复 const dailyKeyPattern = /usage:.+:model:daily:(.+):\d{4}-\d{2}-\d{2}$/ const monthlyKeyPattern = /usage:.+:model:monthly:(.+):\d{4}-\d{2}$/ - - // 检查是否有日数据 - const hasDailyData = uniqueKeys.some((key) => dailyKeyPattern.test(key)) + const currentMonth = `${tzDate.getUTCFullYear()}-${String(tzDate.getUTCMonth() + 1).padStart(2, '0')}` for (let i = 0; i < results.length; i++) { const [err, data] = results[i] @@ -1038,8 +1060,12 @@ async function calculateKeyStats(keyId, timeRange, startDate, endDate) { continue } - // 如果有日数据,则跳过月数据以避免重复 - if (hasDailyData && isMonthly) { + // 跳过当前月的月数据 + if (isMonthly && key.includes(`:${currentMonth}`)) { + continue + } + // 跳过非当前月的日数据 + if (!isMonthly && !key.includes(`:${currentMonth}-`)) { continue } diff --git a/src/routes/admin/claudeAccounts.js b/src/routes/admin/claudeAccounts.js index 13dd1a63..52791374 100644 --- a/src/routes/admin/claudeAccounts.js +++ b/src/routes/admin/claudeAccounts.js @@ -9,6 +9,7 @@ const router = express.Router() const claudeAccountService = require('../../services/claudeAccountService') const claudeRelayService = require('../../services/claudeRelayService') const accountGroupService = require('../../services/accountGroupService') +const accountTestSchedulerService = require('../../services/accountTestSchedulerService') const apiKeyService = require('../../services/apiKeyService') const redis = require('../../models/redis') const { authenticateAdmin } = require('../../middleware/auth') @@ -277,7 +278,7 @@ router.post('/claude-accounts/oauth-with-cookie', authenticateAdmin, async (req, logger.info('🍪 Starting Cookie-based OAuth authorization', { sessionKeyLength: trimmedSessionKey.length, - sessionKeyPrefix: trimmedSessionKey.substring(0, 10) + '...', + sessionKeyPrefix: `${trimmedSessionKey.substring(0, 10)}...`, hasProxy: !!proxy }) @@ -326,7 +327,7 @@ router.post('/claude-accounts/setup-token-with-cookie', authenticateAdmin, async logger.info('🍪 Starting Cookie-based Setup Token authorization', { sessionKeyLength: trimmedSessionKey.length, - sessionKeyPrefix: trimmedSessionKey.substring(0, 10) + '...', + sessionKeyPrefix: `${trimmedSessionKey.substring(0, 10)}...`, hasProxy: !!proxy }) @@ -583,7 +584,8 @@ router.post('/claude-accounts', authenticateAdmin, async (req, res) => { useUnifiedClientId, unifiedClientId, expiresAt, - extInfo + extInfo, + maxConcurrency } = req.body if (!name) { @@ -628,7 +630,8 @@ router.post('/claude-accounts', authenticateAdmin, async (req, res) => { useUnifiedClientId: useUnifiedClientId === true, // 默认为false unifiedClientId: unifiedClientId || '', // 统一的客户端标识 expiresAt: expiresAt || null, // 账户订阅到期时间 - extInfo: extInfo || null + extInfo: extInfo || null, + maxConcurrency: maxConcurrency || 0 // 账户级串行队列:0=使用全局配置,>0=强制启用 }) // 如果是分组类型,将账户添加到分组 @@ -903,4 +906,219 @@ router.post('/claude-accounts/:accountId/test', authenticateAdmin, async (req, r } }) +// ============================================================================ +// 账户定时测试相关端点 +// ============================================================================ + +// 获取账户测试历史 +router.get('/claude-accounts/:accountId/test-history', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + + try { + const history = await redis.getAccountTestHistory(accountId, 'claude') + return res.json({ + success: true, + data: { + accountId, + platform: 'claude', + history + } + }) + } catch (error) { + logger.error(`❌ Failed to get test history for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to get test history', + message: error.message + }) + } +}) + +// 获取账户定时测试配置 +router.get('/claude-accounts/:accountId/test-config', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + + try { + const testConfig = await redis.getAccountTestConfig(accountId, 'claude') + return res.json({ + success: true, + data: { + accountId, + platform: 'claude', + config: testConfig || { + enabled: false, + cronExpression: '0 8 * * *', + model: 'claude-sonnet-4-5-20250929' + } + } + }) + } catch (error) { + logger.error(`❌ Failed to get test config for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to get test config', + message: error.message + }) + } +}) + +// 设置账户定时测试配置 +router.put('/claude-accounts/:accountId/test-config', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + const { enabled, cronExpression, model } = req.body + + try { + // 验证 enabled 参数 + if (typeof enabled !== 'boolean') { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'enabled must be a boolean' + }) + } + + // 验证 cronExpression 参数 + if (!cronExpression || typeof cronExpression !== 'string') { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'cronExpression is required and must be a string' + }) + } + + // 限制 cronExpression 长度防止 DoS + const MAX_CRON_LENGTH = 100 + if (cronExpression.length > MAX_CRON_LENGTH) { + return res.status(400).json({ + error: 'Invalid parameter', + message: `cronExpression too long (max ${MAX_CRON_LENGTH} characters)` + }) + } + + // 使用 service 的方法验证 cron 表达式 + if (!accountTestSchedulerService.validateCronExpression(cronExpression)) { + return res.status(400).json({ + error: 'Invalid parameter', + message: `Invalid cron expression: ${cronExpression}. Format: "minute hour day month weekday" (e.g., "0 8 * * *" for daily at 8:00)` + }) + } + + // 验证模型参数 + const testModel = model || 'claude-sonnet-4-5-20250929' + if (typeof testModel !== 'string' || testModel.length > 256) { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'model must be a valid string (max 256 characters)' + }) + } + + // 检查账户是否存在 + const account = await claudeAccountService.getAccount(accountId) + if (!account) { + return res.status(404).json({ + error: 'Account not found', + message: `Claude account ${accountId} not found` + }) + } + + // 保存配置 + await redis.saveAccountTestConfig(accountId, 'claude', { + enabled, + cronExpression, + model: testModel + }) + + logger.success( + `📝 Updated test config for Claude account ${accountId}: enabled=${enabled}, cronExpression=${cronExpression}, model=${testModel}` + ) + + return res.json({ + success: true, + message: 'Test config updated successfully', + data: { + accountId, + platform: 'claude', + config: { enabled, cronExpression, model: testModel } + } + }) + } catch (error) { + logger.error(`❌ Failed to update test config for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to update test config', + message: error.message + }) + } +}) + +// 手动触发账户测试(非流式,返回JSON结果) +router.post('/claude-accounts/:accountId/test-sync', authenticateAdmin, async (req, res) => { + const { accountId } = req.params + + try { + // 检查账户是否存在 + const account = await claudeAccountService.getAccount(accountId) + if (!account) { + return res.status(404).json({ + error: 'Account not found', + message: `Claude account ${accountId} not found` + }) + } + + logger.info(`🧪 Manual sync test triggered for Claude account: ${accountId}`) + + // 执行测试 + const testResult = await claudeRelayService.testAccountConnectionSync(accountId) + + // 保存测试结果到历史 + await redis.saveAccountTestResult(accountId, 'claude', testResult) + await redis.setAccountLastTestTime(accountId, 'claude') + + return res.json({ + success: true, + data: { + accountId, + platform: 'claude', + result: testResult + } + }) + } catch (error) { + logger.error(`❌ Failed to run sync test for account ${accountId}:`, error) + return res.status(500).json({ + error: 'Failed to run test', + message: error.message + }) + } +}) + +// 批量获取多个账户的测试历史 +router.post('/claude-accounts/batch-test-history', authenticateAdmin, async (req, res) => { + const { accountIds } = req.body + + try { + if (!Array.isArray(accountIds) || accountIds.length === 0) { + return res.status(400).json({ + error: 'Invalid parameter', + message: 'accountIds must be a non-empty array' + }) + } + + // 限制批量查询数量 + const limitedIds = accountIds.slice(0, 100) + + const accounts = limitedIds.map((accountId) => ({ + accountId, + platform: 'claude' + })) + + const historyMap = await redis.getAccountsTestHistory(accounts) + + return res.json({ + success: true, + data: historyMap + }) + } catch (error) { + logger.error('❌ Failed to get batch test history:', error) + return res.status(500).json({ + error: 'Failed to get batch test history', + message: error.message + }) + } +}) + module.exports = router diff --git a/src/routes/admin/claudeRelayConfig.js b/src/routes/admin/claudeRelayConfig.js index e3c78ef4..a41207a9 100644 --- a/src/routes/admin/claudeRelayConfig.js +++ b/src/routes/admin/claudeRelayConfig.js @@ -40,7 +40,14 @@ router.put('/claude-relay-config', authenticateAdmin, async (req, res) => { claudeCodeOnlyEnabled, globalSessionBindingEnabled, sessionBindingErrorMessage, - sessionBindingTtlDays + sessionBindingTtlDays, + userMessageQueueEnabled, + userMessageQueueDelayMs, + userMessageQueueTimeoutMs, + concurrentRequestQueueEnabled, + concurrentRequestQueueMaxSize, + concurrentRequestQueueMaxSizeMultiplier, + concurrentRequestQueueTimeoutMs } = req.body // 验证输入 @@ -78,15 +85,117 @@ router.put('/claude-relay-config', authenticateAdmin, async (req, res) => { } } + // 验证用户消息队列配置 + if (userMessageQueueEnabled !== undefined && typeof userMessageQueueEnabled !== 'boolean') { + return res.status(400).json({ error: 'userMessageQueueEnabled must be a boolean' }) + } + + if (userMessageQueueDelayMs !== undefined) { + if ( + typeof userMessageQueueDelayMs !== 'number' || + userMessageQueueDelayMs < 0 || + userMessageQueueDelayMs > 10000 + ) { + return res + .status(400) + .json({ error: 'userMessageQueueDelayMs must be a number between 0 and 10000' }) + } + } + + if (userMessageQueueTimeoutMs !== undefined) { + if ( + typeof userMessageQueueTimeoutMs !== 'number' || + userMessageQueueTimeoutMs < 1000 || + userMessageQueueTimeoutMs > 300000 + ) { + return res + .status(400) + .json({ error: 'userMessageQueueTimeoutMs must be a number between 1000 and 300000' }) + } + } + + // 验证并发请求排队配置 + if ( + concurrentRequestQueueEnabled !== undefined && + typeof concurrentRequestQueueEnabled !== 'boolean' + ) { + return res.status(400).json({ error: 'concurrentRequestQueueEnabled must be a boolean' }) + } + + if (concurrentRequestQueueMaxSize !== undefined) { + if ( + typeof concurrentRequestQueueMaxSize !== 'number' || + !Number.isInteger(concurrentRequestQueueMaxSize) || + concurrentRequestQueueMaxSize < 1 || + concurrentRequestQueueMaxSize > 100 + ) { + return res + .status(400) + .json({ error: 'concurrentRequestQueueMaxSize must be an integer between 1 and 100' }) + } + } + + if (concurrentRequestQueueMaxSizeMultiplier !== undefined) { + // 使用 Number.isFinite() 同时排除 NaN、Infinity、-Infinity 和非数字类型 + if ( + !Number.isFinite(concurrentRequestQueueMaxSizeMultiplier) || + concurrentRequestQueueMaxSizeMultiplier < 0 || + concurrentRequestQueueMaxSizeMultiplier > 10 + ) { + return res.status(400).json({ + error: 'concurrentRequestQueueMaxSizeMultiplier must be a finite number between 0 and 10' + }) + } + } + + if (concurrentRequestQueueTimeoutMs !== undefined) { + if ( + typeof concurrentRequestQueueTimeoutMs !== 'number' || + !Number.isInteger(concurrentRequestQueueTimeoutMs) || + concurrentRequestQueueTimeoutMs < 5000 || + concurrentRequestQueueTimeoutMs > 300000 + ) { + return res.status(400).json({ + error: + 'concurrentRequestQueueTimeoutMs must be an integer between 5000 and 300000 (5 seconds to 5 minutes)' + }) + } + } + const updateData = {} - if (claudeCodeOnlyEnabled !== undefined) + if (claudeCodeOnlyEnabled !== undefined) { updateData.claudeCodeOnlyEnabled = claudeCodeOnlyEnabled - if (globalSessionBindingEnabled !== undefined) + } + if (globalSessionBindingEnabled !== undefined) { updateData.globalSessionBindingEnabled = globalSessionBindingEnabled - if (sessionBindingErrorMessage !== undefined) + } + if (sessionBindingErrorMessage !== undefined) { updateData.sessionBindingErrorMessage = sessionBindingErrorMessage - if (sessionBindingTtlDays !== undefined) + } + if (sessionBindingTtlDays !== undefined) { updateData.sessionBindingTtlDays = sessionBindingTtlDays + } + if (userMessageQueueEnabled !== undefined) { + updateData.userMessageQueueEnabled = userMessageQueueEnabled + } + if (userMessageQueueDelayMs !== undefined) { + updateData.userMessageQueueDelayMs = userMessageQueueDelayMs + } + if (userMessageQueueTimeoutMs !== undefined) { + updateData.userMessageQueueTimeoutMs = userMessageQueueTimeoutMs + } + if (concurrentRequestQueueEnabled !== undefined) { + updateData.concurrentRequestQueueEnabled = concurrentRequestQueueEnabled + } + if (concurrentRequestQueueMaxSize !== undefined) { + updateData.concurrentRequestQueueMaxSize = concurrentRequestQueueMaxSize + } + if (concurrentRequestQueueMaxSizeMultiplier !== undefined) { + updateData.concurrentRequestQueueMaxSizeMultiplier = concurrentRequestQueueMaxSizeMultiplier + } + if (concurrentRequestQueueTimeoutMs !== undefined) { + updateData.concurrentRequestQueueTimeoutMs = concurrentRequestQueueTimeoutMs + } const updatedConfig = await claudeRelayConfigService.updateConfig( updateData, diff --git a/src/routes/admin/concurrency.js b/src/routes/admin/concurrency.js index 80fee22c..9325b5a8 100644 --- a/src/routes/admin/concurrency.js +++ b/src/routes/admin/concurrency.js @@ -7,26 +7,40 @@ const express = require('express') const router = express.Router() const redis = require('../../models/redis') const logger = require('../../utils/logger') +const { authenticateAdmin } = require('../../middleware/auth') +const { calculateWaitTimeStats } = require('../../utils/statsHelper') /** * GET /admin/concurrency * 获取所有并发状态 */ -router.get('/concurrency', async (req, res) => { +router.get('/concurrency', authenticateAdmin, async (req, res) => { try { const status = await redis.getAllConcurrencyStatus() + // 为每个 API Key 获取排队计数 + const statusWithQueue = await Promise.all( + status.map(async (s) => { + const queueCount = await redis.getConcurrencyQueueCount(s.apiKeyId) + return { + ...s, + queueCount + } + }) + ) + // 计算汇总统计 const summary = { - totalKeys: status.length, - totalActiveRequests: status.reduce((sum, s) => sum + s.activeCount, 0), - totalExpiredRequests: status.reduce((sum, s) => sum + s.expiredCount, 0) + totalKeys: statusWithQueue.length, + totalActiveRequests: statusWithQueue.reduce((sum, s) => sum + s.activeCount, 0), + totalExpiredRequests: statusWithQueue.reduce((sum, s) => sum + s.expiredCount, 0), + totalQueuedRequests: statusWithQueue.reduce((sum, s) => sum + s.queueCount, 0) } res.json({ success: true, summary, - concurrencyStatus: status + concurrencyStatus: statusWithQueue }) } catch (error) { logger.error('❌ Failed to get concurrency status:', error) @@ -39,17 +53,171 @@ router.get('/concurrency', async (req, res) => { }) /** - * GET /admin/concurrency/:apiKeyId - * 获取特定 API Key 的并发状态详情 + * GET /admin/concurrency-queue/stats + * 获取排队统计信息 */ -router.get('/concurrency/:apiKeyId', async (req, res) => { +router.get('/concurrency-queue/stats', authenticateAdmin, async (req, res) => { try { - const { apiKeyId } = req.params - const status = await redis.getConcurrencyStatus(apiKeyId) + // 获取所有有统计数据的 API Key + const statsKeys = await redis.scanConcurrencyQueueStatsKeys() + const queueKeys = await redis.scanConcurrencyQueueKeys() + + // 合并所有相关的 API Key + const allApiKeyIds = [...new Set([...statsKeys, ...queueKeys])] + + // 获取各 API Key 的详细统计 + const perKeyStats = await Promise.all( + allApiKeyIds.map(async (apiKeyId) => { + const [queueCount, stats, waitTimes] = await Promise.all([ + redis.getConcurrencyQueueCount(apiKeyId), + redis.getConcurrencyQueueStats(apiKeyId), + redis.getQueueWaitTimes(apiKeyId) + ]) + + return { + apiKeyId, + currentQueueCount: queueCount, + stats, + waitTimeStats: calculateWaitTimeStats(waitTimes) + } + }) + ) + + // 获取全局等待时间统计 + const globalWaitTimes = await redis.getGlobalQueueWaitTimes() + const globalWaitTimeStats = calculateWaitTimeStats(globalWaitTimes) + + // 计算全局汇总 + const globalStats = { + totalEntered: perKeyStats.reduce((sum, s) => sum + s.stats.entered, 0), + totalSuccess: perKeyStats.reduce((sum, s) => sum + s.stats.success, 0), + totalTimeout: perKeyStats.reduce((sum, s) => sum + s.stats.timeout, 0), + totalCancelled: perKeyStats.reduce((sum, s) => sum + s.stats.cancelled, 0), + totalSocketChanged: perKeyStats.reduce((sum, s) => sum + (s.stats.socket_changed || 0), 0), + totalRejectedOverload: perKeyStats.reduce( + (sum, s) => sum + (s.stats.rejected_overload || 0), + 0 + ), + currentTotalQueued: perKeyStats.reduce((sum, s) => sum + s.currentQueueCount, 0), + // 队列资源利用率指标 + peakQueueSize: + perKeyStats.length > 0 ? Math.max(...perKeyStats.map((s) => s.currentQueueCount)) : 0, + avgQueueSize: + perKeyStats.length > 0 + ? Math.round( + perKeyStats.reduce((sum, s) => sum + s.currentQueueCount, 0) / perKeyStats.length + ) + : 0, + activeApiKeys: perKeyStats.filter((s) => s.currentQueueCount > 0).length + } + + // 计算成功率 + if (globalStats.totalEntered > 0) { + globalStats.successRate = Math.round( + (globalStats.totalSuccess / globalStats.totalEntered) * 100 + ) + globalStats.timeoutRate = Math.round( + (globalStats.totalTimeout / globalStats.totalEntered) * 100 + ) + globalStats.cancelledRate = Math.round( + (globalStats.totalCancelled / globalStats.totalEntered) * 100 + ) + } + + // 从全局等待时间统计中提取关键指标 + if (globalWaitTimeStats) { + globalStats.avgWaitTimeMs = globalWaitTimeStats.avg + globalStats.p50WaitTimeMs = globalWaitTimeStats.p50 + globalStats.p90WaitTimeMs = globalWaitTimeStats.p90 + globalStats.p99WaitTimeMs = globalWaitTimeStats.p99 + // 多实例采样策略标记(详见 design.md Decision 9) + // 全局 P90 仅用于可视化和监控,不用于系统决策 + // 健康检查使用 API Key 级别的 P90(每 Key 独立采样) + globalWaitTimeStats.globalP90ForVisualizationOnly = true + } res.json({ success: true, - concurrencyStatus: status + globalStats, + globalWaitTimeStats, + perKeyStats + }) + } catch (error) { + logger.error('❌ Failed to get queue stats:', error) + res.status(500).json({ + success: false, + error: 'Failed to get queue stats', + message: error.message + }) + } +}) + +/** + * DELETE /admin/concurrency-queue/:apiKeyId + * 清理特定 API Key 的排队计数 + */ +router.delete('/concurrency-queue/:apiKeyId', authenticateAdmin, async (req, res) => { + try { + const { apiKeyId } = req.params + await redis.clearConcurrencyQueue(apiKeyId) + + logger.warn(`🧹 Admin ${req.admin?.username || 'unknown'} cleared queue for key ${apiKeyId}`) + + res.json({ + success: true, + message: `Successfully cleared queue for API key ${apiKeyId}` + }) + } catch (error) { + logger.error(`❌ Failed to clear queue for ${req.params.apiKeyId}:`, error) + res.status(500).json({ + success: false, + error: 'Failed to clear queue', + message: error.message + }) + } +}) + +/** + * DELETE /admin/concurrency-queue + * 清理所有排队计数 + */ +router.delete('/concurrency-queue', authenticateAdmin, async (req, res) => { + try { + const cleared = await redis.clearAllConcurrencyQueues() + + logger.warn(`🧹 Admin ${req.admin?.username || 'unknown'} cleared ALL queues`) + + res.json({ + success: true, + message: 'Successfully cleared all queues', + cleared + }) + } catch (error) { + logger.error('❌ Failed to clear all queues:', error) + res.status(500).json({ + success: false, + error: 'Failed to clear all queues', + message: error.message + }) + } +}) + +/** + * GET /admin/concurrency/:apiKeyId + * 获取特定 API Key 的并发状态详情 + */ +router.get('/concurrency/:apiKeyId', authenticateAdmin, async (req, res) => { + try { + const { apiKeyId } = req.params + const status = await redis.getConcurrencyStatus(apiKeyId) + const queueCount = await redis.getConcurrencyQueueCount(apiKeyId) + + res.json({ + success: true, + concurrencyStatus: { + ...status, + queueCount + } }) } catch (error) { logger.error(`❌ Failed to get concurrency status for ${req.params.apiKeyId}:`, error) @@ -65,7 +233,7 @@ router.get('/concurrency/:apiKeyId', async (req, res) => { * DELETE /admin/concurrency/:apiKeyId * 强制清理特定 API Key 的并发计数 */ -router.delete('/concurrency/:apiKeyId', async (req, res) => { +router.delete('/concurrency/:apiKeyId', authenticateAdmin, async (req, res) => { try { const { apiKeyId } = req.params const result = await redis.forceClearConcurrency(apiKeyId) @@ -93,7 +261,7 @@ router.delete('/concurrency/:apiKeyId', async (req, res) => { * DELETE /admin/concurrency * 强制清理所有并发计数 */ -router.delete('/concurrency', async (req, res) => { +router.delete('/concurrency', authenticateAdmin, async (req, res) => { try { const result = await redis.forceClearAllConcurrency() @@ -118,7 +286,7 @@ router.delete('/concurrency', async (req, res) => { * POST /admin/concurrency/cleanup * 清理过期的并发条目(不影响活跃请求) */ -router.post('/concurrency/cleanup', async (req, res) => { +router.post('/concurrency/cleanup', authenticateAdmin, async (req, res) => { try { const { apiKeyId } = req.body const result = await redis.cleanupExpiredConcurrency(apiKeyId || null) diff --git a/src/routes/api.js b/src/routes/api.js index b2ee018c..8ca1bb08 100644 --- a/src/routes/api.js +++ b/src/routes/api.js @@ -38,6 +38,73 @@ function queueRateLimitUpdate(rateLimitInfo, usageSummary, model, context = '') }) } +/** + * 判断是否为旧会话(污染的会话) + * Claude Code 发送的请求特点: + * - messages 数组通常只有 1 个元素 + * - 历史对话记录嵌套在单个 message 的 content 数组中 + * - content 数组中包含 开头的系统注入内容 + * + * 污染会话的特征: + * 1. messages.length > 1 + * 2. messages.length === 1 但 content 中有多个用户输入 + * 3. "warmup" 请求:单条简单消息 + 无 tools(真正新会话会带 tools) + * + * @param {Object} body - 请求体 + * @returns {boolean} 是否为旧会话 + */ +function isOldSession(body) { + const messages = body?.messages + const tools = body?.tools + + if (!messages || messages.length === 0) { + return false + } + + // 1. 多条消息 = 旧会话 + if (messages.length > 1) { + return true + } + + // 2. 单条消息,分析 content + const firstMessage = messages[0] + const content = firstMessage?.content + + if (!content) { + return false + } + + // 如果 content 是字符串,只有一条输入,需要检查 tools + if (typeof content === 'string') { + // 有 tools = 正常新会话,无 tools = 可疑 + return !tools || tools.length === 0 + } + + // 如果 content 是数组,统计非 system-reminder 的元素 + if (Array.isArray(content)) { + const userInputs = content.filter((item) => { + if (item.type !== 'text') { + return false + } + const text = item.text || '' + // 剔除以 开头的 + return !text.trimStart().startsWith('') + }) + + // 多个用户输入 = 旧会话 + if (userInputs.length > 1) { + return true + } + + // Warmup 检测:单个消息 + 无 tools = 旧会话 + if (userInputs.length === 1 && (!tools || tools.length === 0)) { + return true + } + } + + return false +} + // 🔧 共享的消息处理函数 async function handleMessagesRequest(req, res) { try { @@ -123,12 +190,42 @@ async function handleMessagesRequest(req, res) { ) if (isStream) { + // 🔍 检查客户端连接是否仍然有效(可能在并发排队等待期间断开) + if (res.destroyed || res.socket?.destroyed || res.writableEnded) { + logger.warn( + `⚠️ Client disconnected before stream response could start for key: ${req.apiKey?.name || 'unknown'}` + ) + return undefined + } + // 流式响应 - 只使用官方真实usage数据 res.setHeader('Content-Type', 'text/event-stream') res.setHeader('Cache-Control', 'no-cache') res.setHeader('Connection', 'keep-alive') res.setHeader('Access-Control-Allow-Origin', '*') res.setHeader('X-Accel-Buffering', 'no') // 禁用 Nginx 缓冲 + // ⚠️ 检查 headers 是否已发送(可能在排队心跳时已设置) + if (!res.headersSent) { + res.setHeader('Content-Type', 'text/event-stream') + res.setHeader('Cache-Control', 'no-cache') + // ⚠️ 关键修复:尊重 auth.js 提前设置的 Connection: close + // 当并发队列功能启用时,auth.js 会设置 Connection: close 来禁用 Keep-Alive + // 这里只在没有设置过 Connection 头时才设置 keep-alive + const existingConnection = res.getHeader('Connection') + if (!existingConnection) { + res.setHeader('Connection', 'keep-alive') + } else { + logger.api( + `🔌 [STREAM] Preserving existing Connection header: ${existingConnection} for key: ${req.apiKey?.name || 'unknown'}` + ) + } + res.setHeader('Access-Control-Allow-Origin', '*') + res.setHeader('X-Accel-Buffering', 'no') // 禁用 Nginx 缓冲 + } else { + logger.debug( + `📤 [STREAM] Headers already sent, skipping setHeader for key: ${req.apiKey?.name || 'unknown'}` + ) + } // 禁用 Nagle 算法,确保数据立即发送 if (res.socket && typeof res.socket.setNoDelay === 'function') { @@ -233,19 +330,18 @@ async function handleMessagesRequest(req, res) { } // 🔗 在成功调度后建立会话绑定(仅 claude-official 类型) - // claude-official 只接受:1) 新会话(messages.length=1) 2) 已绑定的会话 + // claude-official 只接受:1) 新会话 2) 已绑定的会话 if ( needSessionBinding && originalSessionIdForBinding && accountId && accountType === 'claude-official' ) { - // 🚫 新会话必须 messages.length === 1 - const messages = req.body?.messages - if (messages && messages.length > 1) { + // 🚫 检测旧会话(污染的会话) + if (isOldSession(req.body)) { const cfg = await claudeRelayConfigService.getConfig() logger.warn( - `🚫 New session with messages.length > 1 rejected: sessionId=${originalSessionIdForBinding}, messages.length=${messages.length}` + `🚫 Old session rejected: sessionId=${originalSessionIdForBinding}, messages.length=${req.body?.messages?.length}, tools.length=${req.body?.tools?.length || 0}, isOldSession=true` ) return res.status(400).json({ error: { @@ -591,12 +687,61 @@ async function handleMessagesRequest(req, res) { } }, 1000) // 1秒后检查 } else { + // 🔍 检查客户端连接是否仍然有效(可能在并发排队等待期间断开) + if (res.destroyed || res.socket?.destroyed || res.writableEnded) { + logger.warn( + `⚠️ Client disconnected before non-stream request could start for key: ${req.apiKey?.name || 'unknown'}` + ) + return undefined + } + // 非流式响应 - 只使用官方真实usage数据 logger.info('📄 Starting non-streaming request', { apiKeyId: req.apiKey.id, apiKeyName: req.apiKey.name }) + // 📊 监听 socket 事件以追踪连接状态变化 + const nonStreamSocket = res.socket + let _clientClosedConnection = false + let _socketCloseTime = null + + if (nonStreamSocket) { + const onSocketEnd = () => { + _clientClosedConnection = true + _socketCloseTime = Date.now() + logger.warn( + `⚠️ [NON-STREAM] Socket 'end' event - client sent FIN | key: ${req.apiKey?.name}, ` + + `requestId: ${req.requestId}, elapsed: ${Date.now() - startTime}ms` + ) + } + const onSocketClose = () => { + _clientClosedConnection = true + logger.warn( + `⚠️ [NON-STREAM] Socket 'close' event | key: ${req.apiKey?.name}, ` + + `requestId: ${req.requestId}, elapsed: ${Date.now() - startTime}ms, ` + + `hadError: ${nonStreamSocket.destroyed}` + ) + } + const onSocketError = (err) => { + logger.error( + `❌ [NON-STREAM] Socket error | key: ${req.apiKey?.name}, ` + + `requestId: ${req.requestId}, error: ${err.message}` + ) + } + + nonStreamSocket.once('end', onSocketEnd) + nonStreamSocket.once('close', onSocketClose) + nonStreamSocket.once('error', onSocketError) + + // 清理监听器(在响应结束后) + res.once('finish', () => { + nonStreamSocket.removeListener('end', onSocketEnd) + nonStreamSocket.removeListener('close', onSocketClose) + nonStreamSocket.removeListener('error', onSocketError) + }) + } + // 生成会话哈希用于sticky会话 const sessionHash = sessionHelper.generateSessionHash(req.body) @@ -684,19 +829,18 @@ async function handleMessagesRequest(req, res) { } // 🔗 在成功调度后建立会话绑定(非流式,仅 claude-official 类型) - // claude-official 只接受:1) 新会话(messages.length=1) 2) 已绑定的会话 + // claude-official 只接受:1) 新会话 2) 已绑定的会话 if ( needSessionBindingNonStream && originalSessionIdForBindingNonStream && accountId && accountType === 'claude-official' ) { - // 🚫 新会话必须 messages.length === 1 - const messages = req.body?.messages - if (messages && messages.length > 1) { + // 🚫 检测旧会话(污染的会话) + if (isOldSession(req.body)) { const cfg = await claudeRelayConfigService.getConfig() logger.warn( - `🚫 New session with messages.length > 1 rejected (non-stream): sessionId=${originalSessionIdForBindingNonStream}, messages.length=${messages.length}` + `🚫 Old session rejected (non-stream): sessionId=${originalSessionIdForBindingNonStream}, messages.length=${req.body?.messages?.length}, tools.length=${req.body?.tools?.length || 0}, isOldSession=true` ) return res.status(400).json({ error: { @@ -802,6 +946,15 @@ async function handleMessagesRequest(req, res) { bodyLength: response.body ? response.body.length : 0 }) + // 🔍 检查客户端连接是否仍然有效 + // 在长时间请求过程中,客户端可能已经断开连接(超时、用户取消等) + if (res.destroyed || res.socket?.destroyed || res.writableEnded) { + logger.warn( + `⚠️ Client disconnected before non-stream response could be sent for key: ${req.apiKey?.name || 'unknown'}` + ) + return undefined + } + res.status(response.statusCode) // 设置响应头,避免 Content-Length 和 Transfer-Encoding 冲突 @@ -867,10 +1020,12 @@ async function handleMessagesRequest(req, res) { logger.warn('⚠️ No usage data found in Claude API JSON response') } + // 使用 Express 内建的 res.json() 发送响应(简单可靠) res.json(jsonData) } catch (parseError) { logger.warn('⚠️ Failed to parse Claude API response as JSON:', parseError.message) logger.info('📄 Raw response body:', response.body) + // 使用 Express 内建的 res.send() 发送响应(简单可靠) res.send(response.body) } @@ -1157,6 +1312,41 @@ router.post('/v1/messages/count_tokens', authenticateApiKey, async (req, res) => }) } + // 🔗 会话绑定验证(与 messages 端点保持一致) + const originalSessionId = claudeRelayConfigService.extractOriginalSessionId(req.body) + const sessionValidation = await claudeRelayConfigService.validateNewSession( + req.body, + originalSessionId + ) + + if (!sessionValidation.valid) { + logger.warn( + `🚫 Session binding validation failed (count_tokens): ${sessionValidation.code} for session ${originalSessionId}` + ) + return res.status(400).json({ + error: { + type: 'session_binding_error', + message: sessionValidation.error + } + }) + } + + // 🔗 检测旧会话(污染的会话)- 仅对需要绑定的新会话检查 + if (sessionValidation.isNewSession && originalSessionId) { + if (isOldSession(req.body)) { + const cfg = await claudeRelayConfigService.getConfig() + logger.warn( + `🚫 Old session rejected (count_tokens): sessionId=${originalSessionId}, messages.length=${req.body?.messages?.length}, tools.length=${req.body?.tools?.length || 0}, isOldSession=true` + ) + return res.status(400).json({ + error: { + type: 'session_binding_error', + message: cfg.sessionBindingErrorMessage || '你的本地session已污染,请清理后使用。' + } + }) + } + } + logger.info(`🔢 Processing token count request for key: ${req.apiKey.name}`) const sessionHash = sessionHelper.generateSessionHash(req.body) diff --git a/src/routes/apiStats.js b/src/routes/apiStats.js index 308b18c6..62614b65 100644 --- a/src/routes/apiStats.js +++ b/src/routes/apiStats.js @@ -206,74 +206,85 @@ router.post('/api/user-stats', async (req, res) => { // 获取验证结果中的完整keyData(包含isActive状态和cost信息) const fullKeyData = keyData - // 计算总费用 - 使用与模型统计相同的逻辑(按模型分别计算) + // 🔧 FIX: 使用 allTimeCost 而不是扫描月度键 + // 计算总费用 - 优先使用持久化的总费用计数器 let totalCost = 0 let formattedCost = '$0.000000' try { const client = redis.getClientSafe() - // 获取所有月度模型统计(与model-stats接口相同的逻辑) - const allModelKeys = await client.keys(`usage:${keyId}:model:monthly:*:*`) - const modelUsageMap = new Map() + // 读取累积的总费用(没有 TTL 的持久键) + const totalCostKey = `usage:cost:total:${keyId}` + const allTimeCost = parseFloat((await client.get(totalCostKey)) || '0') - for (const key of allModelKeys) { - const modelMatch = key.match(/usage:.+:model:monthly:(.+):(\d{4}-\d{2})$/) - if (!modelMatch) { - continue - } + if (allTimeCost > 0) { + totalCost = allTimeCost + formattedCost = CostCalculator.formatCost(allTimeCost) + logger.debug(`📊 使用 allTimeCost 计算用户统计: ${allTimeCost}`) + } else { + // Fallback: 如果 allTimeCost 为空(旧键),尝试月度键 + const allModelKeys = await client.keys(`usage:${keyId}:model:monthly:*:*`) + const modelUsageMap = new Map() - const model = modelMatch[1] - const data = await client.hgetall(key) - - if (data && Object.keys(data).length > 0) { - if (!modelUsageMap.has(model)) { - modelUsageMap.set(model, { - inputTokens: 0, - outputTokens: 0, - cacheCreateTokens: 0, - cacheReadTokens: 0 - }) + for (const key of allModelKeys) { + const modelMatch = key.match(/usage:.+:model:monthly:(.+):(\d{4}-\d{2})$/) + if (!modelMatch) { + continue } - const modelUsage = modelUsageMap.get(model) - modelUsage.inputTokens += parseInt(data.inputTokens) || 0 - modelUsage.outputTokens += parseInt(data.outputTokens) || 0 - modelUsage.cacheCreateTokens += parseInt(data.cacheCreateTokens) || 0 - modelUsage.cacheReadTokens += parseInt(data.cacheReadTokens) || 0 - } - } + const model = modelMatch[1] + const data = await client.hgetall(key) - // 按模型计算费用并汇总 - for (const [model, usage] of modelUsageMap) { - const usageData = { - input_tokens: usage.inputTokens, - output_tokens: usage.outputTokens, - cache_creation_input_tokens: usage.cacheCreateTokens, - cache_read_input_tokens: usage.cacheReadTokens + if (data && Object.keys(data).length > 0) { + if (!modelUsageMap.has(model)) { + modelUsageMap.set(model, { + inputTokens: 0, + outputTokens: 0, + cacheCreateTokens: 0, + cacheReadTokens: 0 + }) + } + + const modelUsage = modelUsageMap.get(model) + modelUsage.inputTokens += parseInt(data.inputTokens) || 0 + modelUsage.outputTokens += parseInt(data.outputTokens) || 0 + modelUsage.cacheCreateTokens += parseInt(data.cacheCreateTokens) || 0 + modelUsage.cacheReadTokens += parseInt(data.cacheReadTokens) || 0 + } } - const costResult = CostCalculator.calculateCost(usageData, model) - totalCost += costResult.costs.total - } + // 按模型计算费用并汇总 + for (const [model, usage] of modelUsageMap) { + const usageData = { + input_tokens: usage.inputTokens, + output_tokens: usage.outputTokens, + cache_creation_input_tokens: usage.cacheCreateTokens, + cache_read_input_tokens: usage.cacheReadTokens + } - // 如果没有模型级别的详细数据,回退到总体数据计算 - if (modelUsageMap.size === 0 && fullKeyData.usage?.total?.allTokens > 0) { - const usage = fullKeyData.usage.total - const costUsage = { - input_tokens: usage.inputTokens || 0, - output_tokens: usage.outputTokens || 0, - cache_creation_input_tokens: usage.cacheCreateTokens || 0, - cache_read_input_tokens: usage.cacheReadTokens || 0 + const costResult = CostCalculator.calculateCost(usageData, model) + totalCost += costResult.costs.total } - const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022') - totalCost = costResult.costs.total - } + // 如果没有模型级别的详细数据,回退到总体数据计算 + if (modelUsageMap.size === 0 && fullKeyData.usage?.total?.allTokens > 0) { + const usage = fullKeyData.usage.total + const costUsage = { + input_tokens: usage.inputTokens || 0, + output_tokens: usage.outputTokens || 0, + cache_creation_input_tokens: usage.cacheCreateTokens || 0, + cache_read_input_tokens: usage.cacheReadTokens || 0 + } - formattedCost = CostCalculator.formatCost(totalCost) + const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022') + totalCost = costResult.costs.total + } + + formattedCost = CostCalculator.formatCost(totalCost) + } } catch (error) { - logger.warn(`Failed to calculate detailed cost for key ${keyId}:`, error) + logger.warn(`Failed to calculate cost for key ${keyId}:`, error) // 回退到简单计算 if (fullKeyData.usage?.total?.allTokens > 0) { const usage = fullKeyData.usage.total diff --git a/src/services/accountTestSchedulerService.js b/src/services/accountTestSchedulerService.js new file mode 100644 index 00000000..59b4c6af --- /dev/null +++ b/src/services/accountTestSchedulerService.js @@ -0,0 +1,420 @@ +/** + * 账户定时测试调度服务 + * 使用 node-cron 支持 crontab 表达式,为每个账户创建独立的定时任务 + */ + +const cron = require('node-cron') +const redis = require('../models/redis') +const logger = require('../utils/logger') + +class AccountTestSchedulerService { + constructor() { + // 存储每个账户的 cron 任务: Map + this.scheduledTasks = new Map() + // 定期刷新配置的间隔 (毫秒) + this.refreshIntervalMs = 60 * 1000 + this.refreshInterval = null + // 当前正在测试的账户 + this.testingAccounts = new Set() + // 是否已启动 + this.isStarted = false + } + + /** + * 验证 cron 表达式是否有效 + * @param {string} cronExpression - cron 表达式 + * @returns {boolean} + */ + validateCronExpression(cronExpression) { + // 长度检查(防止 DoS) + if (!cronExpression || cronExpression.length > 100) { + return false + } + return cron.validate(cronExpression) + } + + /** + * 启动调度器 + */ + async start() { + if (this.isStarted) { + logger.warn('⚠️ Account test scheduler is already running') + return + } + + this.isStarted = true + logger.info('🚀 Starting account test scheduler service (node-cron mode)') + + // 初始化所有已配置账户的定时任务 + await this._refreshAllTasks() + + // 定期刷新配置,以便动态添加/修改的配置能生效 + this.refreshInterval = setInterval(() => { + this._refreshAllTasks() + }, this.refreshIntervalMs) + + logger.info( + `📅 Account test scheduler started (refreshing configs every ${this.refreshIntervalMs / 1000}s)` + ) + } + + /** + * 停止调度器 + */ + stop() { + if (this.refreshInterval) { + clearInterval(this.refreshInterval) + this.refreshInterval = null + } + + // 停止所有 cron 任务 + for (const [accountKey, taskInfo] of this.scheduledTasks.entries()) { + taskInfo.task.stop() + logger.debug(`🛑 Stopped cron task for ${accountKey}`) + } + this.scheduledTasks.clear() + + this.isStarted = false + logger.info('🛑 Account test scheduler stopped') + } + + /** + * 刷新所有账户的定时任务 + * @private + */ + async _refreshAllTasks() { + try { + const platforms = ['claude', 'gemini', 'openai'] + const activeAccountKeys = new Set() + + // 并行加载所有平台的配置 + const allEnabledAccounts = await Promise.all( + platforms.map((platform) => + redis + .getEnabledTestAccounts(platform) + .then((accounts) => accounts.map((acc) => ({ ...acc, platform }))) + .catch((error) => { + logger.warn(`⚠️ Failed to load test accounts for platform ${platform}:`, error) + return [] + }) + ) + ) + + // 展平平台数据 + const flatAccounts = allEnabledAccounts.flat() + + for (const { accountId, cronExpression, model, platform } of flatAccounts) { + if (!cronExpression) { + logger.warn( + `⚠️ Account ${accountId} (${platform}) has no valid cron expression, skipping` + ) + continue + } + + const accountKey = `${platform}:${accountId}` + activeAccountKeys.add(accountKey) + + // 检查是否需要更新任务 + const existingTask = this.scheduledTasks.get(accountKey) + if (existingTask) { + // 如果 cron 表达式和模型都没变,不需要更新 + if (existingTask.cronExpression === cronExpression && existingTask.model === model) { + continue + } + // 配置变了,停止旧任务 + existingTask.task.stop() + logger.info(`🔄 Updating cron task for ${accountKey}: ${cronExpression}, model: ${model}`) + } else { + logger.info(`➕ Creating cron task for ${accountKey}: ${cronExpression}, model: ${model}`) + } + + // 创建新的 cron 任务 + this._createCronTask(accountId, platform, cronExpression, model) + } + + // 清理已删除或禁用的账户任务 + for (const [accountKey, taskInfo] of this.scheduledTasks.entries()) { + if (!activeAccountKeys.has(accountKey)) { + taskInfo.task.stop() + this.scheduledTasks.delete(accountKey) + logger.info(`➖ Removed cron task for ${accountKey} (disabled or deleted)`) + } + } + } catch (error) { + logger.error('❌ Error refreshing account test tasks:', error) + } + } + + /** + * 为单个账户创建 cron 任务 + * @param {string} accountId + * @param {string} platform + * @param {string} cronExpression + * @param {string} model - 测试使用的模型 + * @private + */ + _createCronTask(accountId, platform, cronExpression, model) { + const accountKey = `${platform}:${accountId}` + + // 验证 cron 表达式 + if (!this.validateCronExpression(cronExpression)) { + logger.error(`❌ Invalid cron expression for ${accountKey}: ${cronExpression}`) + return + } + + const task = cron.schedule( + cronExpression, + async () => { + await this._runAccountTest(accountId, platform, model) + }, + { + scheduled: true, + timezone: process.env.TZ || 'Asia/Shanghai' + } + ) + + this.scheduledTasks.set(accountKey, { + task, + cronExpression, + model, + accountId, + platform + }) + } + + /** + * 执行单个账户测试 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {string} model - 测试使用的模型 + * @private + */ + async _runAccountTest(accountId, platform, model) { + const accountKey = `${platform}:${accountId}` + + // 避免重复测试 + if (this.testingAccounts.has(accountKey)) { + logger.debug(`⏳ Account ${accountKey} is already being tested, skipping`) + return + } + + this.testingAccounts.add(accountKey) + + try { + logger.info( + `🧪 Running scheduled test for ${platform} account: ${accountId} (model: ${model})` + ) + + let testResult + + // 根据平台调用对应的测试方法 + switch (platform) { + case 'claude': + testResult = await this._testClaudeAccount(accountId, model) + break + case 'gemini': + testResult = await this._testGeminiAccount(accountId, model) + break + case 'openai': + testResult = await this._testOpenAIAccount(accountId, model) + break + default: + testResult = { + success: false, + error: `Unsupported platform: ${platform}`, + timestamp: new Date().toISOString() + } + } + + // 保存测试结果 + await redis.saveAccountTestResult(accountId, platform, testResult) + + // 更新最后测试时间 + await redis.setAccountLastTestTime(accountId, platform) + + // 记录日志 + if (testResult.success) { + logger.info( + `✅ Scheduled test passed for ${platform} account ${accountId} (${testResult.latencyMs}ms)` + ) + } else { + logger.warn( + `❌ Scheduled test failed for ${platform} account ${accountId}: ${testResult.error}` + ) + } + + return testResult + } catch (error) { + logger.error(`❌ Error testing ${platform} account ${accountId}:`, error) + + const errorResult = { + success: false, + error: error.message, + timestamp: new Date().toISOString() + } + + await redis.saveAccountTestResult(accountId, platform, errorResult) + await redis.setAccountLastTestTime(accountId, platform) + + return errorResult + } finally { + this.testingAccounts.delete(accountKey) + } + } + + /** + * 测试 Claude 账户 + * @param {string} accountId + * @param {string} model - 测试使用的模型 + * @private + */ + async _testClaudeAccount(accountId, model) { + const claudeRelayService = require('./claudeRelayService') + return await claudeRelayService.testAccountConnectionSync(accountId, model) + } + + /** + * 测试 Gemini 账户 + * @param {string} _accountId + * @param {string} _model + * @private + */ + async _testGeminiAccount(_accountId, _model) { + // Gemini 测试暂时返回未实现 + return { + success: false, + error: 'Gemini scheduled test not implemented yet', + timestamp: new Date().toISOString() + } + } + + /** + * 测试 OpenAI 账户 + * @param {string} _accountId + * @param {string} _model + * @private + */ + async _testOpenAIAccount(_accountId, _model) { + // OpenAI 测试暂时返回未实现 + return { + success: false, + error: 'OpenAI scheduled test not implemented yet', + timestamp: new Date().toISOString() + } + } + + /** + * 手动触发账户测试 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {string} model - 测试使用的模型 + * @returns {Promise} 测试结果 + */ + async triggerTest(accountId, platform, model = 'claude-sonnet-4-5-20250929') { + logger.info(`🎯 Manual test triggered for ${platform} account: ${accountId} (model: ${model})`) + return await this._runAccountTest(accountId, platform, model) + } + + /** + * 获取账户测试历史 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} 测试历史 + */ + async getTestHistory(accountId, platform) { + return await redis.getAccountTestHistory(accountId, platform) + } + + /** + * 获取账户测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @returns {Promise} + */ + async getTestConfig(accountId, platform) { + return await redis.getAccountTestConfig(accountId, platform) + } + + /** + * 设置账户测试配置 + * @param {string} accountId - 账户ID + * @param {string} platform - 平台类型 + * @param {Object} testConfig - 测试配置 { enabled: boolean, cronExpression: string, model: string } + * @returns {Promise} + */ + async setTestConfig(accountId, platform, testConfig) { + // 验证 cron 表达式 + if (testConfig.cronExpression && !this.validateCronExpression(testConfig.cronExpression)) { + throw new Error(`Invalid cron expression: ${testConfig.cronExpression}`) + } + + await redis.saveAccountTestConfig(accountId, platform, testConfig) + logger.info( + `📝 Test config updated for ${platform} account ${accountId}: enabled=${testConfig.enabled}, cronExpression=${testConfig.cronExpression}, model=${testConfig.model}` + ) + + // 立即刷新任务,使配置立即生效 + if (this.isStarted) { + await this._refreshAllTasks() + } + } + + /** + * 更新单个账户的定时任务(配置变更时调用) + * @param {string} accountId + * @param {string} platform + */ + async refreshAccountTask(accountId, platform) { + if (!this.isStarted) { + return + } + + const accountKey = `${platform}:${accountId}` + const testConfig = await redis.getAccountTestConfig(accountId, platform) + + // 停止现有任务 + const existingTask = this.scheduledTasks.get(accountKey) + if (existingTask) { + existingTask.task.stop() + this.scheduledTasks.delete(accountKey) + } + + // 如果启用且有有效的 cron 表达式,创建新任务 + if (testConfig?.enabled && testConfig?.cronExpression) { + this._createCronTask(accountId, platform, testConfig.cronExpression, testConfig.model) + logger.info( + `🔄 Refreshed cron task for ${accountKey}: ${testConfig.cronExpression}, model: ${testConfig.model}` + ) + } + } + + /** + * 获取调度器状态 + * @returns {Object} + */ + getStatus() { + const tasks = [] + for (const [accountKey, taskInfo] of this.scheduledTasks.entries()) { + tasks.push({ + accountKey, + accountId: taskInfo.accountId, + platform: taskInfo.platform, + cronExpression: taskInfo.cronExpression, + model: taskInfo.model + }) + } + + return { + running: this.isStarted, + refreshIntervalMs: this.refreshIntervalMs, + scheduledTasksCount: this.scheduledTasks.size, + scheduledTasks: tasks, + currentlyTesting: Array.from(this.testingAccounts) + } + } +} + +// 单例模式 +const accountTestSchedulerService = new AccountTestSchedulerService() + +module.exports = accountTestSchedulerService diff --git a/src/services/bedrockRelayService.js b/src/services/bedrockRelayService.js index e27dfd5c..d04e42b2 100644 --- a/src/services/bedrockRelayService.js +++ b/src/services/bedrockRelayService.js @@ -6,6 +6,7 @@ const { const { fromEnv } = require('@aws-sdk/credential-providers') const logger = require('../utils/logger') const config = require('../../config/config') +const userMessageQueueService = require('./userMessageQueueService') class BedrockRelayService { constructor() { @@ -69,7 +70,68 @@ class BedrockRelayService { // 处理非流式请求 async handleNonStreamRequest(requestBody, bedrockAccount = null) { + const accountId = bedrockAccount?.id + let queueLockAcquired = false + let queueRequestId = null + try { + // 📬 用户消息队列处理 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error('❌ accountId missing for queue lock in Bedrock handleNonStreamRequest') + throw new Error('accountId missing for queue lock') + } + const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续统计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for Bedrock account ${accountId}`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + return { + statusCode, + headers: { + 'Content-Type': 'application/json', + 'x-user-message-queue-error': errorType + }, + body: JSON.stringify({ + type: 'error', + error: { + type: errorType, + code: errorCode, + message: errorMessage + } + }), + success: false + } + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for Bedrock account ${accountId}, requestId: ${queueRequestId}` + ) + } + } + const modelId = this._selectModel(requestBody, bedrockAccount) const region = this._selectRegion(modelId, bedrockAccount) const client = this._getBedrockClient(region, bedrockAccount) @@ -90,6 +152,23 @@ class BedrockRelayService { const response = await client.send(command) const duration = Date.now() - startTime + // 📬 请求已发送成功,立即释放队列锁(无需等待响应处理完成) + // 因为限流基于请求发送时刻计算(RPM),不是请求完成时刻 + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for Bedrock account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for Bedrock account ${accountId}:`, + releaseError.message + ) + } + } + // 解析响应 const responseBody = JSON.parse(new TextDecoder().decode(response.body)) const claudeResponse = this._convertFromBedrockFormat(responseBody) @@ -106,12 +185,94 @@ class BedrockRelayService { } catch (error) { logger.error('❌ Bedrock非流式请求失败:', error) throw this._handleBedrockError(error) + } finally { + // 📬 释放用户消息队列锁(兜底,正常情况下已在请求发送后提前释放) + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for Bedrock account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for Bedrock account ${accountId}:`, + releaseError.message + ) + } + } } } // 处理流式请求 async handleStreamRequest(requestBody, bedrockAccount = null, res) { + const accountId = bedrockAccount?.id + let queueLockAcquired = false + let queueRequestId = null + try { + // 📬 用户消息队列处理 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error('❌ accountId missing for queue lock in Bedrock handleStreamRequest') + throw new Error('accountId missing for queue lock') + } + const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续统计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + stream: true, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for Bedrock account ${accountId} (stream)`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + if (!res.headersSent) { + const existingConnection = res.getHeader ? res.getHeader('Connection') : null + res.writeHead(statusCode, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: existingConnection || 'keep-alive', + 'x-user-message-queue-error': errorType + }) + } + const errorEvent = `event: error\ndata: ${JSON.stringify({ + type: 'error', + error: { + type: errorType, + code: errorCode, + message: errorMessage + } + })}\n\n` + res.write(errorEvent) + res.write('data: [DONE]\n\n') + res.end() + return { success: false, error: errorType } + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for Bedrock account ${accountId} (stream), requestId: ${queueRequestId}` + ) + } + } + const modelId = this._selectModel(requestBody, bedrockAccount) const region = this._selectRegion(modelId, bedrockAccount) const client = this._getBedrockClient(region, bedrockAccount) @@ -131,11 +292,35 @@ class BedrockRelayService { const startTime = Date.now() const response = await client.send(command) + // 📬 请求已发送成功,立即释放队列锁(无需等待响应处理完成) + // 因为限流基于请求发送时刻计算(RPM),不是请求完成时刻 + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for Bedrock stream account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for Bedrock stream account ${accountId}:`, + releaseError.message + ) + } + } + // 设置SSE响应头 + // ⚠️ 关键修复:尊重 auth.js 提前设置的 Connection: close + const existingConnection = res.getHeader ? res.getHeader('Connection') : null + if (existingConnection) { + logger.debug( + `🔌 [Bedrock Stream] Preserving existing Connection header: ${existingConnection}` + ) + } res.writeHead(200, { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', - Connection: 'keep-alive', + Connection: existingConnection || 'keep-alive', 'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Headers': 'Content-Type, Authorization' }) @@ -191,6 +376,21 @@ class BedrockRelayService { res.end() throw this._handleBedrockError(error) + } finally { + // 📬 释放用户消息队列锁(兜底,正常情况下已在请求发送后提前释放) + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for Bedrock stream account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for Bedrock stream account ${accountId}:`, + releaseError.message + ) + } + } } } diff --git a/src/services/ccrRelayService.js b/src/services/ccrRelayService.js index 50ad7b58..d5f97c9f 100644 --- a/src/services/ccrRelayService.js +++ b/src/services/ccrRelayService.js @@ -3,6 +3,8 @@ const ccrAccountService = require('./ccrAccountService') const logger = require('../utils/logger') const config = require('../../config/config') const { parseVendorPrefixedModel } = require('../utils/modelHelper') +const userMessageQueueService = require('./userMessageQueueService') +const { isStreamWritable } = require('../utils/streamHelper') class CcrRelayService { constructor() { @@ -21,8 +23,67 @@ class CcrRelayService { ) { let abortController = null let account = null + let queueLockAcquired = false + let queueRequestId = null try { + // 📬 用户消息队列处理 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error('❌ accountId missing for queue lock in CCR relayRequest') + throw new Error('accountId missing for queue lock') + } + const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续统计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for CCR account ${accountId}`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + return { + statusCode, + headers: { + 'Content-Type': 'application/json', + 'x-user-message-queue-error': errorType + }, + body: JSON.stringify({ + type: 'error', + error: { + type: errorType, + code: errorCode, + message: errorMessage + } + }), + accountId + } + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for CCR account ${accountId}, requestId: ${queueRequestId}` + ) + } + } + // 获取账户信息 account = await ccrAccountService.getAccount(accountId) if (!account) { @@ -162,6 +223,23 @@ class CcrRelayService { ) const response = await axios(requestConfig) + // 📬 请求已发送成功,立即释放队列锁(无需等待响应处理完成) + // 因为 Claude API 限流基于请求发送时刻计算(RPM),不是请求完成时刻 + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for CCR account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for CCR account ${accountId}:`, + releaseError.message + ) + } + } + // 移除监听器(请求成功完成) if (clientRequest) { clientRequest.removeListener('close', handleClientDisconnect) @@ -233,6 +311,21 @@ class CcrRelayService { ) throw error + } finally { + // 📬 释放用户消息队列锁(兜底,正常情况下已在请求发送后提前释放) + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for CCR account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for CCR account ${accountId}:`, + releaseError.message + ) + } + } } } @@ -248,7 +341,77 @@ class CcrRelayService { options = {} ) { let account = null + let queueLockAcquired = false + let queueRequestId = null + try { + // 📬 用户消息队列处理 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error( + '❌ accountId missing for queue lock in CCR relayStreamRequestWithUsageCapture' + ) + throw new Error('accountId missing for queue lock') + } + const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续��计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + stream: true, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for CCR account ${accountId} (stream)`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null + responseStream.writeHead(statusCode, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: existingConnection || 'keep-alive', + 'x-user-message-queue-error': errorType + }) + } + const errorEvent = `event: error\ndata: ${JSON.stringify({ + type: 'error', + error: { + type: errorType, + code: errorCode, + message: errorMessage + } + })}\n\n` + responseStream.write(errorEvent) + responseStream.write('data: [DONE]\n\n') + responseStream.end() + return + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for CCR account ${accountId} (stream), requestId: ${queueRequestId}` + ) + } + } + // 获取账户信息 account = await ccrAccountService.getAccount(accountId) if (!account) { @@ -296,14 +459,53 @@ class CcrRelayService { accountId, usageCallback, streamTransformer, - options + options, + // 📬 回调:在收到响应头时释放队列锁 + async () => { + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for CCR stream account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for CCR stream account ${accountId}:`, + releaseError.message + ) + } + } + } ) // 更新最后使用时间 await this._updateLastUsedTime(accountId) } catch (error) { - logger.error(`❌ CCR stream relay failed (Account: ${account?.name || accountId}):`, error) + // 客户端主动断开连接是正常情况,使用 INFO 级别 + if (error.message === 'Client disconnected') { + logger.info( + `🔌 CCR stream relay ended: Client disconnected (Account: ${account?.name || accountId})` + ) + } else { + logger.error(`❌ CCR stream relay failed (Account: ${account?.name || accountId}):`, error) + } throw error + } finally { + // 📬 释放用户消息队列锁(兜底,正常情况下已在收到响应头后提前释放) + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for CCR stream account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for CCR stream account ${accountId}:`, + releaseError.message + ) + } + } } } @@ -317,7 +519,8 @@ class CcrRelayService { accountId, usageCallback, streamTransformer = null, - requestOptions = {} + requestOptions = {}, + onResponseHeaderReceived = null ) { return new Promise((resolve, reject) => { let aborted = false @@ -380,8 +583,11 @@ class CcrRelayService { // 发送请求 const request = axios(requestConfig) + // 注意:使用 .then(async ...) 模式处理响应 + // - 内部的 releaseQueueLock 有独立的 try-catch,不会导致未捕获异常 + // - queueLockAcquired = false 的赋值会在 finally 执行前完成(JS 单线程保证) request - .then((response) => { + .then(async (response) => { logger.debug(`🌊 CCR stream response status: ${response.status}`) // 错误响应处理 @@ -404,10 +610,13 @@ class CcrRelayService { // 设置错误响应的状态码和响应头 if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null const errorHeaders = { 'Content-Type': response.headers['content-type'] || 'application/json', 'Cache-Control': 'no-cache', - Connection: 'keep-alive' + Connection: existingConnection || 'keep-alive' } // 避免 Transfer-Encoding 冲突,让 Express 自动处理 delete errorHeaders['Transfer-Encoding'] @@ -417,13 +626,13 @@ class CcrRelayService { // 直接透传错误数据,不进行包装 response.data.on('data', (chunk) => { - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.write(chunk) } }) response.data.on('end', () => { - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.end() } resolve() // 不抛出异常,正常完成流处理 @@ -431,6 +640,19 @@ class CcrRelayService { return } + // 📬 收到成功响应头(HTTP 200),调用回调释放队列锁 + // 此时请求已被 Claude API 接受并计入 RPM 配额,无需等待响应完成 + if (onResponseHeaderReceived && typeof onResponseHeaderReceived === 'function') { + try { + await onResponseHeaderReceived() + } catch (callbackError) { + logger.error( + `❌ Failed to execute onResponseHeaderReceived callback for CCR stream account ${accountId}:`, + callbackError.message + ) + } + } + // 成功响应,检查并移除错误状态 ccrAccountService.isAccountRateLimited(accountId).then((isRateLimited) => { if (isRateLimited) { @@ -444,11 +666,20 @@ class CcrRelayService { }) // 设置响应头 + // ⚠️ 关键修复:尊重 auth.js 提前设置的 Connection: close if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null + if (existingConnection) { + logger.debug( + `🔌 [CCR Stream] Preserving existing Connection header: ${existingConnection}` + ) + } const headers = { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', - Connection: 'keep-alive', + Connection: existingConnection || 'keep-alive', 'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Headers': 'Cache-Control' } @@ -487,12 +718,17 @@ class CcrRelayService { } // 写入到响应流 - if (outputLine && !responseStream.destroyed) { + if (outputLine && isStreamWritable(responseStream)) { responseStream.write(`${outputLine}\n`) + } else if (outputLine) { + // 客户端连接已断开,记录警告 + logger.warn( + `⚠️ [CCR] Client disconnected during stream, skipping data for account: ${accountId}` + ) } } else { // 空行也需要传递 - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.write('\n') } } @@ -503,10 +739,6 @@ class CcrRelayService { }) response.data.on('end', () => { - if (!responseStream.destroyed) { - responseStream.end() - } - // 如果收集到使用统计数据,调用回调 if (usageCallback && Object.keys(collectedUsage).length > 0) { try { @@ -518,12 +750,26 @@ class CcrRelayService { } } - resolve() + if (isStreamWritable(responseStream)) { + // 等待数据完全 flush 到客户端后再 resolve + responseStream.end(() => { + logger.debug( + `🌊 CCR stream response completed and flushed | bytesWritten: ${responseStream.bytesWritten || 'unknown'}` + ) + resolve() + }) + } else { + // 连接已断开,记录警告 + logger.warn( + `⚠️ [CCR] Client disconnected before stream end, data may not have been received | account: ${accountId}` + ) + resolve() + } }) response.data.on('error', (err) => { logger.error('❌ Stream data error:', err) - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.end() } reject(err) @@ -555,7 +801,7 @@ class CcrRelayService { } } - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.write(`data: ${JSON.stringify(errorResponse)}\n\n`) responseStream.end() } diff --git a/src/services/claudeAccountService.js b/src/services/claudeAccountService.js index 77630364..35ce9cff 100644 --- a/src/services/claudeAccountService.js +++ b/src/services/claudeAccountService.js @@ -91,7 +91,8 @@ class ClaudeAccountService { useUnifiedClientId = false, // 是否使用统一的客户端标识 unifiedClientId = '', // 统一的客户端标识 expiresAt = null, // 账户订阅到期时间 - extInfo = null // 额外扩展信息 + extInfo = null, // 额外扩展信息 + maxConcurrency = 0 // 账户级用户消息串行队列:0=使用全局配置,>0=强制启用串行 } = options const accountId = uuidv4() @@ -136,7 +137,9 @@ class ClaudeAccountService { // 账户订阅到期时间 subscriptionExpiresAt: expiresAt || '', // 扩展信息 - extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '' + extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '', + // 账户级用户消息串行队列限制 + maxConcurrency: maxConcurrency.toString() } } else { // 兼容旧格式 @@ -168,7 +171,9 @@ class ClaudeAccountService { // 账户订阅到期时间 subscriptionExpiresAt: expiresAt || '', // 扩展信息 - extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '' + extInfo: normalizedExtInfo ? JSON.stringify(normalizedExtInfo) : '', + // 账户级用户消息串行队列限制 + maxConcurrency: maxConcurrency.toString() } } @@ -574,7 +579,9 @@ class ClaudeAccountService { // 添加停止原因 stoppedReason: account.stoppedReason || null, // 扩展信息 - extInfo: parsedExtInfo + extInfo: parsedExtInfo, + // 账户级用户消息串行队列限制 + maxConcurrency: parseInt(account.maxConcurrency || '0', 10) } }) ) @@ -666,7 +673,8 @@ class ClaudeAccountService { 'useUnifiedClientId', 'unifiedClientId', 'subscriptionExpiresAt', - 'extInfo' + 'extInfo', + 'maxConcurrency' ] const updatedData = { ...accountData } let shouldClearAutoStopFields = false @@ -681,7 +689,7 @@ class ClaudeAccountService { updatedData[field] = this._encryptSensitiveData(value) } else if (field === 'proxy') { updatedData[field] = value ? JSON.stringify(value) : '' - } else if (field === 'priority') { + } else if (field === 'priority' || field === 'maxConcurrency') { updatedData[field] = value.toString() } else if (field === 'subscriptionInfo') { // 处理订阅信息更新 diff --git a/src/services/claudeConsoleRelayService.js b/src/services/claudeConsoleRelayService.js index 08e56653..31e8af83 100644 --- a/src/services/claudeConsoleRelayService.js +++ b/src/services/claudeConsoleRelayService.js @@ -9,6 +9,9 @@ const { sanitizeErrorMessage, isAccountDisabledError } = require('../utils/errorSanitizer') +const userMessageQueueService = require('./userMessageQueueService') +const { isStreamWritable } = require('../utils/streamHelper') +const { filterForClaude } = require('../utils/headerFilter') class ClaudeConsoleRelayService { constructor() { @@ -29,8 +32,68 @@ class ClaudeConsoleRelayService { let account = null const requestId = uuidv4() // 用于并发追踪 let concurrencyAcquired = false + let queueLockAcquired = false + let queueRequestId = null try { + // 📬 用户消息队列处理:如果是用户消息请求,需要获取队列锁 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error('❌ accountId missing for queue lock in console relayRequest') + throw new Error('accountId missing for queue lock') + } + const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续统计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + apiKeyName: apiKeyData.name, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for console account ${accountId}, key: ${apiKeyData.name}`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + return { + statusCode, + headers: { + 'Content-Type': 'application/json', + 'x-user-message-queue-error': errorType + }, + body: JSON.stringify({ + type: 'error', + error: { + type: errorType, + code: errorCode, + message: errorMessage + } + }), + accountId + } + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for console account ${accountId}, requestId: ${queueRequestId}` + ) + } + } + // 获取账户信息 account = await claudeConsoleAccountService.getAccount(accountId) if (!account) { @@ -203,6 +266,23 @@ class ClaudeConsoleRelayService { ) const response = await axios(requestConfig) + // 📬 请求已发送成功,立即释放队列锁(无需等待响应处理完成) + // 因为 Claude API 限流基于请求发送时刻计算(RPM),不是请求完成时刻 + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for console account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for console account ${accountId}:`, + releaseError.message + ) + } + } + // 移除监听器(请求成功完成) if (clientRequest) { clientRequest.removeListener('close', handleClientDisconnect) @@ -366,6 +446,21 @@ class ClaudeConsoleRelayService { ) } } + + // 📬 释放用户消息队列锁(兜底,正常情况下已在请求发送后提前释放) + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for console account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for account ${accountId}:`, + releaseError.message + ) + } + } } } @@ -384,8 +479,71 @@ class ClaudeConsoleRelayService { const requestId = uuidv4() // 用于并发追踪 let concurrencyAcquired = false let leaseRefreshInterval = null // 租约刷新定时器 + let queueLockAcquired = false + let queueRequestId = null try { + // 📬 用户消息队列处理:如果是用户消息请求,需要获取队列锁 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error( + '❌ accountId missing for queue lock in console relayStreamRequestWithUsageCapture' + ) + throw new Error('accountId missing for queue lock') + } + const queueResult = await userMessageQueueService.acquireQueueLock(accountId) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续统计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + stream: true, + apiKeyName: apiKeyData.name, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for console account ${accountId} (stream), key: ${apiKeyData.name}`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null + responseStream.writeHead(statusCode, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: existingConnection || 'keep-alive', + 'x-user-message-queue-error': errorType + }) + } + const errorEvent = `event: error\ndata: ${JSON.stringify({ type: 'error', error: { type: errorType, code: errorCode, message: errorMessage } })}\n\n` + responseStream.write(errorEvent) + responseStream.write('data: [DONE]\n\n') + responseStream.end() + return + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for console account ${accountId} (stream), requestId: ${queueRequestId}` + ) + } + } + // 获取账户信息 account = await claudeConsoleAccountService.getAccount(accountId) if (!account) { @@ -483,16 +641,40 @@ class ClaudeConsoleRelayService { accountId, usageCallback, streamTransformer, - options + options, + // 📬 回调:在收到响应头时释放队列锁 + async () => { + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for console stream account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for console stream account ${accountId}:`, + releaseError.message + ) + } + } + } ) // 更新最后使用时间 await this._updateLastUsedTime(accountId) } catch (error) { - logger.error( - `❌ Claude Console stream relay failed (Account: ${account?.name || accountId}):`, - error - ) + // 客户端主动断开连接是正常情况,使用 INFO 级别 + if (error.message === 'Client disconnected') { + logger.info( + `🔌 Claude Console stream relay ended: Client disconnected (Account: ${account?.name || accountId})` + ) + } else { + logger.error( + `❌ Claude Console stream relay failed (Account: ${account?.name || accountId}):`, + error + ) + } throw error } finally { // 🛑 清理租约刷新定时器 @@ -517,6 +699,21 @@ class ClaudeConsoleRelayService { ) } } + + // 📬 释放用户消息队列锁(兜底,正常情况下已在收到响应头后提前释放) + if (queueLockAcquired && queueRequestId && accountId) { + try { + await userMessageQueueService.releaseQueueLock(accountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for console stream account ${accountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for stream account ${accountId}:`, + releaseError.message + ) + } + } } } @@ -530,7 +727,8 @@ class ClaudeConsoleRelayService { accountId, usageCallback, streamTransformer = null, - requestOptions = {} + requestOptions = {}, + onResponseHeaderReceived = null ) { return new Promise((resolve, reject) => { let aborted = false @@ -593,8 +791,11 @@ class ClaudeConsoleRelayService { // 发送请求 const request = axios(requestConfig) + // 注意:使用 .then(async ...) 模式处理响应 + // - 内部的 releaseQueueLock 有独立的 try-catch,不会导致未捕获异常 + // - queueLockAcquired = false 的赋值会在 finally 执行前完成(JS 单线程保证) request - .then((response) => { + .then(async (response) => { logger.debug(`🌊 Claude Console Claude stream response status: ${response.status}`) // 错误响应处理 @@ -682,7 +883,7 @@ class ClaudeConsoleRelayService { `🧹 [Stream] [SANITIZED] Error response to client: ${JSON.stringify(sanitizedError)}` ) - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.write(JSON.stringify(sanitizedError)) responseStream.end() } @@ -690,7 +891,7 @@ class ClaudeConsoleRelayService { const sanitizedText = sanitizeErrorMessage(errorDataForCheck) logger.error(`🧹 [Stream] [SANITIZED] Error response to client: ${sanitizedText}`) - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.write(sanitizedText) responseStream.end() } @@ -701,6 +902,19 @@ class ClaudeConsoleRelayService { return } + // 📬 收到成功响应头(HTTP 200),调用回调释放队列锁 + // 此时请求已被 Claude API 接受并计入 RPM 配额,无需等待响应完成 + if (onResponseHeaderReceived && typeof onResponseHeaderReceived === 'function') { + try { + await onResponseHeaderReceived() + } catch (callbackError) { + logger.error( + `❌ Failed to execute onResponseHeaderReceived callback for console stream account ${accountId}:`, + callbackError.message + ) + } + } + // 成功响应,检查并移除错误状态 claudeConsoleAccountService.isAccountRateLimited(accountId).then((isRateLimited) => { if (isRateLimited) { @@ -714,11 +928,22 @@ class ClaudeConsoleRelayService { }) // 设置响应头 + // ⚠️ 关键修复:尊重 auth.js 提前设置的 Connection: close + // 当并发队列功能启用时,auth.js 会设置 Connection: close 来禁用 Keep-Alive if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null + const connectionHeader = existingConnection || 'keep-alive' + if (existingConnection) { + logger.debug( + `🔌 [Console Stream] Preserving existing Connection header: ${existingConnection}` + ) + } responseStream.writeHead(200, { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', - Connection: 'keep-alive', + Connection: connectionHeader, 'X-Accel-Buffering': 'no' }) } @@ -744,20 +969,33 @@ class ClaudeConsoleRelayService { buffer = lines.pop() || '' // 转发数据并解析usage - if (lines.length > 0 && !responseStream.destroyed) { - const linesToForward = lines.join('\n') + (lines.length > 0 ? '\n' : '') + if (lines.length > 0) { + // 检查流是否可写(客户端连接是否有效) + if (isStreamWritable(responseStream)) { + const linesToForward = lines.join('\n') + (lines.length > 0 ? '\n' : '') - // 应用流转换器如果有 - if (streamTransformer) { - const transformed = streamTransformer(linesToForward) - if (transformed) { - responseStream.write(transformed) + // 应用流转换器如果有 + let dataToWrite = linesToForward + if (streamTransformer) { + const transformed = streamTransformer(linesToForward) + if (transformed) { + dataToWrite = transformed + } else { + dataToWrite = null + } + } + + if (dataToWrite) { + responseStream.write(dataToWrite) } } else { - responseStream.write(linesToForward) + // 客户端连接已断开,记录警告(但仍继续解析usage) + logger.warn( + `⚠️ [Console] Client disconnected during stream, skipping ${lines.length} lines for account: ${account?.name || accountId}` + ) } - // 解析SSE数据寻找usage信息 + // 解析SSE数据寻找usage信息(无论连接状态如何) for (const line of lines) { if (line.startsWith('data:')) { const jsonStr = line.slice(5).trimStart() @@ -865,7 +1103,7 @@ class ClaudeConsoleRelayService { `❌ Error processing Claude Console stream data (Account: ${account?.name || accountId}):`, error ) - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { // 如果有 streamTransformer(如测试请求),使用前端期望的格式 if (streamTransformer) { responseStream.write( @@ -888,7 +1126,7 @@ class ClaudeConsoleRelayService { response.data.on('end', () => { try { // 处理缓冲区中剩余的数据 - if (buffer.trim() && !responseStream.destroyed) { + if (buffer.trim() && isStreamWritable(responseStream)) { if (streamTransformer) { const transformed = streamTransformer(buffer) if (transformed) { @@ -937,12 +1175,33 @@ class ClaudeConsoleRelayService { } // 确保流正确结束 - if (!responseStream.destroyed) { - responseStream.end() - } + if (isStreamWritable(responseStream)) { + // 📊 诊断日志:流结束前状态 + logger.info( + `📤 [STREAM] Ending response | destroyed: ${responseStream.destroyed}, ` + + `socketDestroyed: ${responseStream.socket?.destroyed}, ` + + `socketBytesWritten: ${responseStream.socket?.bytesWritten || 0}` + ) - logger.debug('🌊 Claude Console Claude stream response completed') - resolve() + // 禁用 Nagle 算法确保数据立即发送 + if (responseStream.socket && !responseStream.socket.destroyed) { + responseStream.socket.setNoDelay(true) + } + + // 等待数据完全 flush 到客户端后再 resolve + responseStream.end(() => { + logger.info( + `✅ [STREAM] Response ended and flushed | socketBytesWritten: ${responseStream.socket?.bytesWritten || 'unknown'}` + ) + resolve() + }) + } else { + // 连接已断开,记录警告 + logger.warn( + `⚠️ [Console] Client disconnected before stream end, data may not have been received | account: ${account?.name || accountId}` + ) + resolve() + } } catch (error) { logger.error('❌ Error processing stream end:', error) reject(error) @@ -954,7 +1213,7 @@ class ClaudeConsoleRelayService { `❌ Claude Console stream error (Account: ${account?.name || accountId}):`, error ) - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { // 如果有 streamTransformer(如测试请求),使用前端期望的格式 if (streamTransformer) { responseStream.write( @@ -1002,14 +1261,17 @@ class ClaudeConsoleRelayService { // 发送错误响应 if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null responseStream.writeHead(error.response?.status || 500, { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', - Connection: 'keep-alive' + Connection: existingConnection || 'keep-alive' }) } - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { // 如果有 streamTransformer(如测试请求),使用前端期望的格式 if (streamTransformer) { responseStream.write( @@ -1041,30 +1303,9 @@ class ClaudeConsoleRelayService { // 🔧 过滤客户端请求头 _filterClientHeaders(clientHeaders) { - const sensitiveHeaders = [ - 'content-type', - 'user-agent', - 'authorization', - 'x-api-key', - 'host', - 'content-length', - 'connection', - 'proxy-authorization', - 'content-encoding', - 'transfer-encoding', - 'anthropic-version' - ] - - const filteredHeaders = {} - - Object.keys(clientHeaders || {}).forEach((key) => { - const lowerKey = key.toLowerCase() - if (!sensitiveHeaders.includes(lowerKey)) { - filteredHeaders[key] = clientHeaders[key] - } - }) - - return filteredHeaders + // 使用统一的 headerFilter 工具类(白名单模式) + // 与 claudeRelayService 保持一致,避免透传 CDN headers 触发上游 API 安全检查 + return filterForClaude(clientHeaders) } // 🕐 更新最后使用时间 @@ -1179,7 +1420,7 @@ class ClaudeConsoleRelayService { 'Cache-Control': 'no-cache' }) } - if (!responseStream.destroyed && !responseStream.writableEnded) { + if (isStreamWritable(responseStream)) { responseStream.write( `data: ${JSON.stringify({ type: 'test_complete', success: false, error: error.message })}\n\n` ) diff --git a/src/services/claudeRelayConfigService.js b/src/services/claudeRelayConfigService.js index 3b9790ac..4fa2b411 100644 --- a/src/services/claudeRelayConfigService.js +++ b/src/services/claudeRelayConfigService.js @@ -15,6 +15,20 @@ const DEFAULT_CONFIG = { globalSessionBindingEnabled: false, sessionBindingErrorMessage: '你的本地session已污染,请清理后使用。', sessionBindingTtlDays: 30, // 会话绑定 TTL(天),默认30天 + // 用户消息队列配置 + userMessageQueueEnabled: false, // 是否启用用户消息队列(默认关闭) + userMessageQueueDelayMs: 200, // 请求间隔(毫秒) + userMessageQueueTimeoutMs: 5000, // 队列等待超时(毫秒),优化后锁持有时间短无需长等待 + userMessageQueueLockTtlMs: 5000, // 锁TTL(毫秒),请求发送后立即释放无需长TTL + // 并发请求排队配置 + concurrentRequestQueueEnabled: false, // 是否启用并发请求排队(默认关闭) + concurrentRequestQueueMaxSize: 3, // 固定最小排队数(默认3) + concurrentRequestQueueMaxSizeMultiplier: 0, // 并发数的倍数(默认0,仅使用固定值) + concurrentRequestQueueTimeoutMs: 10000, // 排队超时(毫秒,默认10秒) + concurrentRequestQueueMaxRedisFailCount: 5, // 连续 Redis 失败阈值(默认5次) + // 排队健康检查配置 + concurrentRequestQueueHealthCheckEnabled: true, // 是否启用排队健康检查(默认开启) + concurrentRequestQueueHealthThreshold: 0.8, // 健康检查阈值(P90 >= 超时 × 阈值时拒绝新请求) updatedAt: null, updatedBy: null } @@ -100,7 +114,8 @@ class ClaudeRelayConfigService { logger.info(`✅ Claude relay config updated by ${updatedBy}:`, { claudeCodeOnlyEnabled: updatedConfig.claudeCodeOnlyEnabled, - globalSessionBindingEnabled: updatedConfig.globalSessionBindingEnabled + globalSessionBindingEnabled: updatedConfig.globalSessionBindingEnabled, + concurrentRequestQueueEnabled: updatedConfig.concurrentRequestQueueEnabled }) return updatedConfig @@ -283,12 +298,13 @@ class ClaudeRelayConfigService { const account = await accountService.getAccount(accountId) - if (!account || !account.success || !account.data) { + // getAccount() 直接返回账户数据对象或 null,不是 { success, data } 格式 + if (!account) { logger.warn(`Session binding account not found: ${accountId} (${accountType})`) return false } - const accountData = account.data + const accountData = account // 检查账户是否激活 if (accountData.isActive === false || accountData.isActive === 'false') { @@ -315,11 +331,11 @@ class ClaudeRelayConfigService { /** * 验证新会话请求 - * @param {Object} requestBody - 请求体 + * @param {Object} _requestBody - 请求体(预留参数,当前未使用) * @param {string} originalSessionId - 原始会话ID * @returns {Promise} { valid: boolean, error?: string, binding?: object, isNewSession?: boolean } */ - async validateNewSession(requestBody, originalSessionId) { + async validateNewSession(_requestBody, originalSessionId) { const cfg = await this.getConfig() if (!cfg.globalSessionBindingEnabled) { diff --git a/src/services/claudeRelayService.js b/src/services/claudeRelayService.js index 998b14ef..40c6103b 100644 --- a/src/services/claudeRelayService.js +++ b/src/services/claudeRelayService.js @@ -15,6 +15,8 @@ const ClaudeCodeValidator = require('../validators/clients/claudeCodeValidator') const { formatDateWithTimezone } = require('../utils/dateHelper') const requestIdentityService = require('./requestIdentityService') const { createClaudeTestPayload } = require('../utils/testPayloadHelper') +const userMessageQueueService = require('./userMessageQueueService') +const { isStreamWritable } = require('../utils/streamHelper') class ClaudeRelayService { constructor() { @@ -148,6 +150,9 @@ class ClaudeRelayService { options = {} ) { let upstreamRequest = null + let queueLockAcquired = false + let queueRequestId = null + let selectedAccountId = null try { // 调试日志:查看API Key数据 @@ -192,11 +197,80 @@ class ClaudeRelayService { } const { accountId } = accountSelection const { accountType } = accountSelection + selectedAccountId = accountId logger.info( `📤 Processing API request for key: ${apiKeyData.name || apiKeyData.id}, account: ${accountId} (${accountType})${sessionHash ? `, session: ${sessionHash}` : ''}` ) + // 📬 用户消息队列处理:如果是用户消息请求,需要获取队列锁 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error('❌ accountId missing for queue lock in relayRequest') + throw new Error('accountId missing for queue lock') + } + // 获取账户信息以检查账户级串行队列配置 + const accountForQueue = await claudeAccountService.getAccount(accountId) + const accountConfig = accountForQueue + ? { maxConcurrency: parseInt(accountForQueue.maxConcurrency || '0', 10) } + : null + const queueResult = await userMessageQueueService.acquireQueueLock( + accountId, + null, + null, + accountConfig + ) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续统计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + apiKeyName: apiKeyData.name, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for account ${accountId}, key: ${apiKeyData.name}`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + return { + statusCode, + headers: { + 'Content-Type': 'application/json', + 'x-user-message-queue-error': errorType + }, + body: JSON.stringify({ + type: 'error', + error: { + type: errorType, + code: errorCode, + message: errorMessage + } + }), + accountId + } + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for account ${accountId}, requestId: ${queueRequestId}` + ) + } + } + // 获取账户信息 let account = await claudeAccountService.getAccount(accountId) @@ -271,6 +345,23 @@ class ClaudeRelayService { options ) + // 📬 请求已发送成功,立即释放队列锁(无需等待响应处理完成) + // 因为 Claude API 限流基于请求发送时刻计算(RPM),不是请求完成时刻 + if (queueLockAcquired && queueRequestId && selectedAccountId) { + try { + await userMessageQueueService.releaseQueueLock(selectedAccountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for account ${selectedAccountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for account ${selectedAccountId}:`, + releaseError.message + ) + } + } + response.accountId = accountId response.accountType = accountType @@ -539,6 +630,21 @@ class ClaudeRelayService { error.message ) throw error + } finally { + // 📬 释放用户消息队列锁(兜底,正常情况下已在请求发送后提前释放) + if (queueLockAcquired && queueRequestId && selectedAccountId) { + try { + await userMessageQueueService.releaseQueueLock(selectedAccountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for account ${selectedAccountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for account ${selectedAccountId}:`, + releaseError.message + ) + } + } } } @@ -962,6 +1068,8 @@ class ClaudeRelayService { logger.info(`🔗 指纹是这个: ${headers['User-Agent']}`) + logger.info(`🔗 指纹是这个: ${headers['User-Agent']}`) + // 根据模型和客户端传递的 anthropic-beta 动态设置 header const modelId = requestPayload?.model || body?.model const clientBetaHeader = clientHeaders?.['anthropic-beta'] @@ -1057,8 +1165,6 @@ class ClaudeRelayService { timeout: config.requestTimeout || 600000 } - console.log(options.path) - const req = https.request(options, (res) => { let responseData = Buffer.alloc(0) @@ -1112,7 +1218,6 @@ class ClaudeRelayService { } req.on('error', async (error) => { - console.error(': ❌ ', error) logger.error(`❌ Claude API request error (Account: ${accountId}):`, error.message, { code: error.code, errno: error.errno, @@ -1163,6 +1268,10 @@ class ClaudeRelayService { streamTransformer = null, options = {} ) { + let queueLockAcquired = false + let queueRequestId = null + let selectedAccountId = null + try { // 调试日志:查看API Key数据(流式请求) logger.info('🔍 [Stream] API Key data received:', { @@ -1206,6 +1315,83 @@ class ClaudeRelayService { } const { accountId } = accountSelection const { accountType } = accountSelection + selectedAccountId = accountId + + // 📬 用户消息队列处理:如果是用户消息请求,需要获取队列锁 + if (userMessageQueueService.isUserMessageRequest(requestBody)) { + // 校验 accountId 非空,避免空值污染队列锁键 + if (!accountId || accountId === '') { + logger.error('❌ accountId missing for queue lock in relayStreamRequestWithUsageCapture') + throw new Error('accountId missing for queue lock') + } + // 获取账户信息以检查账户级串行队列配置 + const accountForQueue = await claudeAccountService.getAccount(accountId) + const accountConfig = accountForQueue + ? { maxConcurrency: parseInt(accountForQueue.maxConcurrency || '0', 10) } + : null + const queueResult = await userMessageQueueService.acquireQueueLock( + accountId, + null, + null, + accountConfig + ) + if (!queueResult.acquired && !queueResult.skipped) { + // 区分 Redis 后端错误和队列超时 + const isBackendError = queueResult.error === 'queue_backend_error' + const errorCode = isBackendError ? 'QUEUE_BACKEND_ERROR' : 'QUEUE_TIMEOUT' + const errorType = isBackendError ? 'queue_backend_error' : 'queue_timeout' + const errorMessage = isBackendError + ? 'Queue service temporarily unavailable, please retry later' + : 'User message queue wait timeout, please retry later' + const statusCode = isBackendError ? 500 : 503 + + // 结构化性能日志,用于后续统计 + logger.performance('user_message_queue_error', { + errorType, + errorCode, + accountId, + statusCode, + stream: true, + apiKeyName: apiKeyData.name, + backendError: isBackendError ? queueResult.errorMessage : undefined + }) + + logger.warn( + `📬 User message queue ${errorType} for account ${accountId} (stream), key: ${apiKeyData.name}`, + isBackendError ? { backendError: queueResult.errorMessage } : {} + ) + if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null + responseStream.writeHead(statusCode, { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + Connection: existingConnection || 'keep-alive', + 'x-user-message-queue-error': errorType + }) + } + const errorEvent = `event: error\ndata: ${JSON.stringify({ + type: 'error', + error: { + type: errorType, + code: errorCode, + message: errorMessage + } + })}\n\n` + responseStream.write(errorEvent) + responseStream.write('data: [DONE]\n\n') + responseStream.end() + return + } + if (queueResult.acquired && !queueResult.skipped) { + queueLockAcquired = true + queueRequestId = queueResult.requestId + logger.debug( + `📬 User message queue lock acquired for account ${accountId} (stream), requestId: ${queueRequestId}` + ) + } + } logger.info( `📡 Processing streaming API request with usage capture for key: ${apiKeyData.name || apiKeyData.id}, account: ${accountId} (${accountType})${sessionHash ? `, session: ${sessionHash}` : ''}` @@ -1272,11 +1458,48 @@ class ClaudeRelayService { sessionHash, streamTransformer, options, - isDedicatedOfficialAccount + isDedicatedOfficialAccount, + // 📬 新增回调:在收到响应头时释放队列锁 + async () => { + if (queueLockAcquired && queueRequestId && selectedAccountId) { + try { + await userMessageQueueService.releaseQueueLock(selectedAccountId, queueRequestId) + queueLockAcquired = false // 标记已释放,防止 finally 重复释放 + logger.debug( + `📬 User message queue lock released early for stream account ${selectedAccountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock early for stream account ${selectedAccountId}:`, + releaseError.message + ) + } + } + } ) } catch (error) { - logger.error(`❌ Claude stream relay with usage capture failed:`, error) + // 客户端主动断开连接是正常情况,使用 INFO 级别 + if (error.message === 'Client disconnected') { + logger.info(`🔌 Claude stream relay ended: Client disconnected`) + } else { + logger.error(`❌ Claude stream relay with usage capture failed:`, error) + } throw error + } finally { + // 📬 释放用户消息队列锁(兜底,正常情况下已在收到响应头后提前释放) + if (queueLockAcquired && queueRequestId && selectedAccountId) { + try { + await userMessageQueueService.releaseQueueLock(selectedAccountId, queueRequestId) + logger.debug( + `📬 User message queue lock released in finally for stream account ${selectedAccountId}, requestId: ${queueRequestId}` + ) + } catch (releaseError) { + logger.error( + `❌ Failed to release user message queue lock for stream account ${selectedAccountId}:`, + releaseError.message + ) + } + } } } @@ -1293,7 +1516,8 @@ class ClaudeRelayService { sessionHash, streamTransformer = null, requestOptions = {}, - isDedicatedOfficialAccount = false + isDedicatedOfficialAccount = false, + onResponseStart = null // 📬 新增:收到响应头时的回调,用于提前释放队列锁 ) { // 获取账户信息用于统一 User-Agent const account = await claudeAccountService.getAccount(accountId) @@ -1478,7 +1702,6 @@ class ClaudeRelayService { }) res.on('end', () => { - console.error(': ❌ ', errorData) logger.error( `❌ Claude API error response (Account: ${account?.name || accountId}):`, errorData @@ -1502,7 +1725,7 @@ class ClaudeRelayService { } })() } - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { // 解析 Claude API 返回的错误详情 let errorMessage = `Claude API error: ${res.statusCode}` try { @@ -1540,6 +1763,16 @@ class ClaudeRelayService { return } + // 📬 收到成功响应头(HTTP 200),立即调用回调释放队列锁 + // 此时请求已被 Claude API 接受并计入 RPM 配额,无需等待响应完成 + if (onResponseStart && typeof onResponseStart === 'function') { + try { + await onResponseStart() + } catch (callbackError) { + logger.error('❌ Error in onResponseStart callback:', callbackError.message) + } + } + let buffer = '' const allUsageData = [] // 收集所有的usage事件 let currentUsageData = {} // 当前正在收集的usage数据 @@ -1557,16 +1790,23 @@ class ClaudeRelayService { buffer = lines.pop() || '' // 保留最后的不完整行 // 转发已处理的完整行到客户端 - if (lines.length > 0 && !responseStream.destroyed) { - const linesToForward = lines.join('\n') + (lines.length > 0 ? '\n' : '') - // 如果有流转换器,应用转换 - if (streamTransformer) { - const transformed = streamTransformer(linesToForward) - if (transformed) { - responseStream.write(transformed) + if (lines.length > 0) { + if (isStreamWritable(responseStream)) { + const linesToForward = lines.join('\n') + (lines.length > 0 ? '\n' : '') + // 如果有流转换器,应用转换 + if (streamTransformer) { + const transformed = streamTransformer(linesToForward) + if (transformed) { + responseStream.write(transformed) + } + } else { + responseStream.write(linesToForward) } } else { - responseStream.write(linesToForward) + // 客户端连接已断开,记录警告(但仍继续解析usage) + logger.warn( + `⚠️ [Official] Client disconnected during stream, skipping ${lines.length} lines for account: ${accountId}` + ) } } @@ -1671,7 +1911,7 @@ class ClaudeRelayService { } catch (error) { logger.error('❌ Error processing stream data:', error) // 发送错误但不破坏流,让它自然结束 - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.write('event: error\n') responseStream.write( `data: ${JSON.stringify({ @@ -1687,7 +1927,7 @@ class ClaudeRelayService { res.on('end', async () => { try { // 处理缓冲区中剩余的数据 - if (buffer.trim() && !responseStream.destroyed) { + if (buffer.trim() && isStreamWritable(responseStream)) { if (streamTransformer) { const transformed = streamTransformer(buffer) if (transformed) { @@ -1699,8 +1939,16 @@ class ClaudeRelayService { } // 确保流正确结束 - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { responseStream.end() + logger.debug( + `🌊 Stream end called | bytesWritten: ${responseStream.bytesWritten || 'unknown'}` + ) + } else { + // 连接已断开,记录警告 + logger.warn( + `⚠️ [Official] Client disconnected before stream end, data may not have been received | account: ${account?.name || accountId}` + ) } } catch (error) { logger.error('❌ Error processing stream end:', error) @@ -1898,14 +2146,17 @@ class ClaudeRelayService { } if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null responseStream.writeHead(statusCode, { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', - Connection: 'keep-alive' + Connection: existingConnection || 'keep-alive' }) } - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { // 发送 SSE 错误事件 responseStream.write('event: error\n') responseStream.write( @@ -1925,13 +2176,16 @@ class ClaudeRelayService { logger.error(`❌ Claude stream request timeout | Account: ${account?.name || accountId}`) if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null responseStream.writeHead(504, { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', - Connection: 'keep-alive' + Connection: existingConnection || 'keep-alive' }) } - if (!responseStream.destroyed) { + if (isStreamWritable(responseStream)) { // 发送 SSE 错误事件 responseStream.write('event: error\n') responseStream.write( @@ -1950,7 +2204,7 @@ class ClaudeRelayService { responseStream.on('close', () => { logger.debug('🔌 Client disconnected, cleaning up stream') if (!req.destroyed) { - req.destroy() + req.destroy(new Error('Client disconnected')) } }) @@ -2222,34 +2476,44 @@ class ClaudeRelayService { } } + // 🔧 准备测试请求的公共逻辑(供 testAccountConnection 和 testAccountConnectionSync 共用) + async _prepareAccountForTest(accountId) { + // 获取账户信息 + const account = await claudeAccountService.getAccount(accountId) + if (!account) { + throw new Error('Account not found') + } + + // 获取有效的访问token + const accessToken = await claudeAccountService.getValidAccessToken(accountId) + if (!accessToken) { + throw new Error('Failed to get valid access token') + } + + // 获取代理配置 + const proxyAgent = await this._getProxyAgent(accountId) + + return { account, accessToken, proxyAgent } + } + // 🧪 测试账号连接(供Admin API使用,直接复用 _makeClaudeStreamRequestWithUsageCapture) - async testAccountConnection(accountId, responseStream) { - const testRequestBody = createClaudeTestPayload('claude-sonnet-4-5-20250929', { stream: true }) + async testAccountConnection(accountId, responseStream, model = 'claude-sonnet-4-5-20250929') { + const testRequestBody = createClaudeTestPayload(model, { stream: true }) try { - // 获取账户信息 - const account = await claudeAccountService.getAccount(accountId) - if (!account) { - throw new Error('Account not found') - } + const { account, accessToken, proxyAgent } = await this._prepareAccountForTest(accountId) logger.info(`🧪 Testing Claude account connection: ${account.name} (${accountId})`) - // 获取有效的访问token - const accessToken = await claudeAccountService.getValidAccessToken(accountId) - if (!accessToken) { - throw new Error('Failed to get valid access token') - } - - // 获取代理配置 - const proxyAgent = await this._getProxyAgent(accountId) - // 设置响应头 if (!responseStream.headersSent) { + const existingConnection = responseStream.getHeader + ? responseStream.getHeader('Connection') + : null responseStream.writeHead(200, { 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', - Connection: 'keep-alive', + Connection: existingConnection || 'keep-alive', 'X-Accel-Buffering': 'no' }) } @@ -2277,7 +2541,7 @@ class ClaudeRelayService { } catch (error) { logger.error(`❌ Test account connection failed:`, error) // 发送错误事件给前端 - if (!responseStream.destroyed && !responseStream.writableEnded) { + if (isStreamWritable(responseStream)) { try { const errorMsg = error.message || '测试失败' responseStream.write(`data: ${JSON.stringify({ type: 'error', error: errorMsg })}\n\n`) @@ -2289,6 +2553,125 @@ class ClaudeRelayService { } } + // 🧪 非流式测试账号连接(供定时任务使用) + // 复用流式请求方法,收集结果后返回 + async testAccountConnectionSync(accountId, model = 'claude-sonnet-4-5-20250929') { + const testRequestBody = createClaudeTestPayload(model, { stream: true }) + const startTime = Date.now() + + try { + // 使用公共方法准备测试所需的账户信息、token 和代理 + const { account, accessToken, proxyAgent } = await this._prepareAccountForTest(accountId) + + logger.info(`🧪 Testing Claude account connection (sync): ${account.name} (${accountId})`) + + // 创建一个收集器来捕获流式响应 + let responseText = '' + let capturedUsage = null + let capturedModel = model + let hasError = false + let errorMessage = '' + + // 创建模拟的响应流对象 + const mockResponseStream = { + headersSent: true, // 跳过设置响应头 + write: (data) => { + // 解析 SSE 数据 + if (typeof data === 'string' && data.startsWith('data: ')) { + try { + const jsonStr = data.replace('data: ', '').trim() + if (jsonStr && jsonStr !== '[DONE]') { + const parsed = JSON.parse(jsonStr) + // 提取文本内容 + if (parsed.type === 'content_block_delta' && parsed.delta?.text) { + responseText += parsed.delta.text + } + // 提取 usage 信息 + if (parsed.type === 'message_delta' && parsed.usage) { + capturedUsage = parsed.usage + } + // 提取模型信息 + if (parsed.type === 'message_start' && parsed.message?.model) { + capturedModel = parsed.message.model + } + // 检测错误 + if (parsed.type === 'error') { + hasError = true + errorMessage = parsed.error?.message || 'Unknown error' + } + } + } catch { + // 忽略解析错误 + } + } + return true + }, + end: () => {}, + on: () => {}, + once: () => {}, + emit: () => {}, + writable: true + } + + // 复用流式请求方法 + await this._makeClaudeStreamRequestWithUsageCapture( + testRequestBody, + accessToken, + proxyAgent, + {}, // clientHeaders - 测试不需要客户端headers + mockResponseStream, + null, // usageCallback - 测试不需要统计 + accountId, + 'claude-official', // accountType + null, // sessionHash - 测试不需要会话 + null, // streamTransformer - 不需要转换,直接解析原始格式 + {}, // requestOptions + false // isDedicatedOfficialAccount + ) + + const latencyMs = Date.now() - startTime + + if (hasError) { + logger.warn(`⚠️ Test completed with error for account: ${account.name} - ${errorMessage}`) + return { + success: false, + error: errorMessage, + latencyMs, + timestamp: new Date().toISOString() + } + } + + logger.info(`✅ Test completed for account: ${account.name} (${latencyMs}ms)`) + + return { + success: true, + message: responseText.substring(0, 200), // 截取前200字符 + latencyMs, + model: capturedModel, + usage: capturedUsage, + timestamp: new Date().toISOString() + } + } catch (error) { + const latencyMs = Date.now() - startTime + logger.error(`❌ Test account connection (sync) failed:`, error.message) + + // 提取错误详情 + let errorMessage = error.message + if (error.response) { + errorMessage = + error.response.data?.error?.message || error.response.statusText || error.message + } + + return { + success: false, + error: errorMessage, + statusCode: error.response?.status, + latencyMs, + timestamp: new Date().toISOString() + } + } + } + // 🎯 健康检查 async healthCheck() { try { diff --git a/src/services/droidRelayService.js b/src/services/droidRelayService.js index e62d5e85..115be7d9 100644 --- a/src/services/droidRelayService.js +++ b/src/services/droidRelayService.js @@ -336,7 +336,12 @@ class DroidRelayService { ) } } catch (error) { - logger.error(`❌ Droid relay error: ${error.message}`, error) + // 客户端主动断开连接是正常情况,使用 INFO 级别 + if (error.message === 'Client disconnected') { + logger.info(`🔌 Droid relay ended: Client disconnected`) + } else { + logger.error(`❌ Droid relay error: ${error.message}`, error) + } const status = error?.response?.status if (status >= 400 && status < 500) { @@ -634,7 +639,7 @@ class DroidRelayService { // 客户端断开连接时清理 clientResponse.on('close', () => { if (req && !req.destroyed) { - req.destroy() + req.destroy(new Error('Client disconnected')) } }) diff --git a/src/services/openaiResponsesRelayService.js b/src/services/openaiResponsesRelayService.js index 04a806b5..688e6ca7 100644 --- a/src/services/openaiResponsesRelayService.js +++ b/src/services/openaiResponsesRelayService.js @@ -426,9 +426,9 @@ class OpenAIResponsesRelayService { const lines = data.split('\n') for (const line of lines) { - if (line.startsWith('data: ')) { + if (line.startsWith('data:')) { try { - const jsonStr = line.slice(6) + const jsonStr = line.slice(5).trim() if (jsonStr === '[DONE]') { continue } diff --git a/src/services/unifiedOpenAIScheduler.js b/src/services/unifiedOpenAIScheduler.js index cedd8b8a..6027df59 100644 --- a/src/services/unifiedOpenAIScheduler.js +++ b/src/services/unifiedOpenAIScheduler.js @@ -9,6 +9,26 @@ class UnifiedOpenAIScheduler { this.SESSION_MAPPING_PREFIX = 'unified_openai_session_mapping:' } + // 🔢 按优先级和最后使用时间排序账户(与 Claude/Gemini 调度保持一致) + _sortAccountsByPriority(accounts) { + return accounts.sort((a, b) => { + const aPriority = Number.parseInt(a.priority, 10) + const bPriority = Number.parseInt(b.priority, 10) + const normalizedAPriority = Number.isFinite(aPriority) ? aPriority : 50 + const normalizedBPriority = Number.isFinite(bPriority) ? bPriority : 50 + + // 首先按优先级排序(数字越小优先级越高) + if (normalizedAPriority !== normalizedBPriority) { + return normalizedAPriority - normalizedBPriority + } + + // 优先级相同时,按最后使用时间排序(最久未使用的优先) + const aLastUsed = new Date(a.lastUsedAt || 0).getTime() + const bLastUsed = new Date(b.lastUsedAt || 0).getTime() + return aLastUsed - bLastUsed + }) + } + // 🔧 辅助方法:检查账户是否可调度(兼容字符串和布尔值) _isSchedulable(schedulable) { // 如果是 undefined 或 null,默认为可调度 @@ -244,13 +264,7 @@ class UnifiedOpenAIScheduler { `🎯 Using bound dedicated ${accountType} account: ${boundAccount.name} (${boundAccount.id}) for API key ${apiKeyData.name}` ) // 更新账户的最后使用时间 - if (accountType === 'openai') { - await openaiAccountService.recordUsage(boundAccount.id, 0) - } else { - await openaiResponsesAccountService.updateAccount(boundAccount.id, { - lastUsedAt: new Date().toISOString() - }) - } + await this.updateAccountLastUsed(boundAccount.id, accountType) return { accountId: boundAccount.id, accountType @@ -292,7 +306,7 @@ class UnifiedOpenAIScheduler { `🎯 Using sticky session account: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}` ) // 更新账户的最后使用时间 - await openaiAccountService.recordUsage(mappedAccount.accountId, 0) + await this.updateAccountLastUsed(mappedAccount.accountId, mappedAccount.accountType) return mappedAccount } else { logger.warn( @@ -321,12 +335,8 @@ class UnifiedOpenAIScheduler { } } - // 按最后使用时间排序(最久未使用的优先,与 Claude 保持一致) - const sortedAccounts = availableAccounts.sort((a, b) => { - const aLastUsed = new Date(a.lastUsedAt || 0).getTime() - const bLastUsed = new Date(b.lastUsedAt || 0).getTime() - return aLastUsed - bLastUsed // 最久未使用的优先 - }) + // 按优先级和最后使用时间排序(与 Claude/Gemini 调度保持一致) + const sortedAccounts = this._sortAccountsByPriority(availableAccounts) // 选择第一个账户 const selectedAccount = sortedAccounts[0] @@ -344,11 +354,11 @@ class UnifiedOpenAIScheduler { } logger.info( - `🎯 Selected account: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for API key ${apiKeyData.name}` + `🎯 Selected account: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}, priority: ${selectedAccount.priority || 50}) for API key ${apiKeyData.name}` ) // 更新账户的最后使用时间 - await openaiAccountService.recordUsage(selectedAccount.accountId, 0) + await this.updateAccountLastUsed(selectedAccount.accountId, selectedAccount.accountType) return { accountId: selectedAccount.accountId, @@ -494,21 +504,6 @@ class UnifiedOpenAIScheduler { return availableAccounts } - // 🔢 按优先级和最后使用时间排序账户(已废弃,改为与 Claude 保持一致,只按最后使用时间排序) - // _sortAccountsByPriority(accounts) { - // return accounts.sort((a, b) => { - // // 首先按优先级排序(数字越小优先级越高) - // if (a.priority !== b.priority) { - // return a.priority - b.priority - // } - - // // 优先级相同时,按最后使用时间排序(最久未使用的优先) - // const aLastUsed = new Date(a.lastUsedAt || 0).getTime() - // const bLastUsed = new Date(b.lastUsedAt || 0).getTime() - // return aLastUsed - bLastUsed - // }) - // } - // 🔍 检查账户是否可用 async _isAccountAvailable(accountId, accountType) { try { @@ -817,7 +812,7 @@ class UnifiedOpenAIScheduler { `🎯 Using sticky session account from group: ${mappedAccount.accountId} (${mappedAccount.accountType})` ) // 更新账户的最后使用时间 - await openaiAccountService.recordUsage(mappedAccount.accountId, 0) + await this.updateAccountLastUsed(mappedAccount.accountId, mappedAccount.accountType) return mappedAccount } } @@ -909,12 +904,8 @@ class UnifiedOpenAIScheduler { throw error } - // 按最后使用时间排序(最久未使用的优先,与 Claude 保持一致) - const sortedAccounts = availableAccounts.sort((a, b) => { - const aLastUsed = new Date(a.lastUsedAt || 0).getTime() - const bLastUsed = new Date(b.lastUsedAt || 0).getTime() - return aLastUsed - bLastUsed // 最久未使用的优先 - }) + // 按优先级和最后使用时间排序(与 Claude/Gemini 调度保持一致) + const sortedAccounts = this._sortAccountsByPriority(availableAccounts) // 选择第一个账户 const selectedAccount = sortedAccounts[0] @@ -932,11 +923,11 @@ class UnifiedOpenAIScheduler { } logger.info( - `🎯 Selected account from group: ${selectedAccount.name} (${selectedAccount.accountId})` + `🎯 Selected account from group: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}, priority: ${selectedAccount.priority || 50})` ) // 更新账户的最后使用时间 - await openaiAccountService.recordUsage(selectedAccount.accountId, 0) + await this.updateAccountLastUsed(selectedAccount.accountId, selectedAccount.accountType) return { accountId: selectedAccount.accountId, @@ -958,9 +949,12 @@ class UnifiedOpenAIScheduler { async updateAccountLastUsed(accountId, accountType) { try { if (accountType === 'openai') { - await openaiAccountService.updateAccount(accountId, { - lastUsedAt: new Date().toISOString() - }) + await openaiAccountService.recordUsage(accountId, 0) + return + } + + if (accountType === 'openai-responses') { + await openaiResponsesAccountService.recordUsage(accountId, 0) } } catch (error) { logger.warn(`⚠️ Failed to update last used time for account ${accountId}:`, error) diff --git a/src/services/userMessageQueueService.js b/src/services/userMessageQueueService.js new file mode 100644 index 00000000..2b4784a2 --- /dev/null +++ b/src/services/userMessageQueueService.js @@ -0,0 +1,359 @@ +/** + * 用户消息队列服务 + * 为 Claude 账户实现基于消息类型的串行排队机制 + * + * 当请求的最后一条消息是用户输入(role: user)时, + * 同一账户的此类请求需要串行等待,并在请求之间添加延迟 + */ + +const { v4: uuidv4 } = require('uuid') +const redis = require('../models/redis') +const config = require('../../config/config') +const logger = require('../utils/logger') + +// 清理任务间隔 +const CLEANUP_INTERVAL_MS = 60000 // 1分钟 + +// 轮询等待配置 +const POLL_INTERVAL_BASE_MS = 50 // 基础轮询间隔 +const POLL_INTERVAL_MAX_MS = 500 // 最大轮询间隔 +const POLL_BACKOFF_FACTOR = 1.5 // 退避因子 + +class UserMessageQueueService { + constructor() { + this.cleanupTimer = null + } + + /** + * 检测请求是否为真正的用户消息请求 + * 区分真正的用户输入和 tool_result 消息 + * + * Claude API 消息格式: + * - 用户文本消息: { role: 'user', content: 'text' } 或 { role: 'user', content: [{ type: 'text', text: '...' }] } + * - 工具结果消息: { role: 'user', content: [{ type: 'tool_result', tool_use_id: '...', content: '...' }] } + * + * @param {Object} requestBody - 请求体 + * @returns {boolean} - 是否为真正的用户消息(排除 tool_result) + */ + isUserMessageRequest(requestBody) { + const messages = requestBody?.messages + if (!Array.isArray(messages) || messages.length === 0) { + return false + } + const lastMessage = messages[messages.length - 1] + + // 检查 role 是否为 user + if (lastMessage?.role !== 'user') { + return false + } + + // 检查 content 是否包含 tool_result 类型 + const { content } = lastMessage + if (Array.isArray(content)) { + // 如果 content 数组中任何元素是 tool_result,则不是真正的用户消息 + const hasToolResult = content.some( + (block) => block?.type === 'tool_result' || block?.type === 'tool_use_result' + ) + if (hasToolResult) { + return false + } + } + + // role 是 user 且不包含 tool_result,是真正的用户消息 + return true + } + + /** + * 获取当前配置(支持 Web 界面配置优先) + * @returns {Promise} 配置对象 + */ + async getConfig() { + // 默认配置(防止 config.userMessageQueue 未定义) + // 注意:优化后的默认值 - 锁持有时间从分钟级降到毫秒级,无需长等待 + const queueConfig = config.userMessageQueue || {} + const defaults = { + enabled: queueConfig.enabled ?? false, + delayMs: queueConfig.delayMs ?? 200, + timeoutMs: queueConfig.timeoutMs ?? 5000, // 从 60000 降到 5000,因为锁持有时间短 + lockTtlMs: queueConfig.lockTtlMs ?? 5000 // 从 120000 降到 5000,5秒足以覆盖请求发送 + } + + // 尝试从 claudeRelayConfigService 获取 Web 界面配置 + try { + const claudeRelayConfigService = require('./claudeRelayConfigService') + const webConfig = await claudeRelayConfigService.getConfig() + + return { + enabled: + webConfig.userMessageQueueEnabled !== undefined + ? webConfig.userMessageQueueEnabled + : defaults.enabled, + delayMs: + webConfig.userMessageQueueDelayMs !== undefined + ? webConfig.userMessageQueueDelayMs + : defaults.delayMs, + timeoutMs: + webConfig.userMessageQueueTimeoutMs !== undefined + ? webConfig.userMessageQueueTimeoutMs + : defaults.timeoutMs, + lockTtlMs: + webConfig.userMessageQueueLockTtlMs !== undefined + ? webConfig.userMessageQueueLockTtlMs + : defaults.lockTtlMs + } + } catch { + // 回退到环境变量配置 + return defaults + } + } + + /** + * 检查功能是否启用 + * @returns {Promise} + */ + async isEnabled() { + const cfg = await this.getConfig() + return cfg.enabled === true + } + + /** + * 获取账户队列锁(阻塞等待) + * @param {string} accountId - 账户ID + * @param {string} requestId - 请求ID(可选,会自动生成) + * @param {number} timeoutMs - 超时时间(可选,使用配置默认值) + * @param {Object} accountConfig - 账户级配置(可选),优先级高于全局配置 + * @param {number} accountConfig.maxConcurrency - 账户级串行队列开关:>0启用,=0使用全局配置 + * @returns {Promise<{acquired: boolean, requestId: string, error?: string}>} + */ + async acquireQueueLock(accountId, requestId = null, timeoutMs = null, accountConfig = null) { + const cfg = await this.getConfig() + + // 账户级配置优先:maxConcurrency > 0 时强制启用,忽略全局开关 + let queueEnabled = cfg.enabled + if (accountConfig && accountConfig.maxConcurrency > 0) { + queueEnabled = true + logger.debug( + `📬 User message queue: account-level queue enabled for account ${accountId} (maxConcurrency=${accountConfig.maxConcurrency})` + ) + } + + if (!queueEnabled) { + return { acquired: true, requestId: requestId || uuidv4(), skipped: true } + } + + const reqId = requestId || uuidv4() + const timeout = timeoutMs || cfg.timeoutMs + const startTime = Date.now() + let retryCount = 0 + + logger.debug(`📬 User message queue: attempting to acquire lock for account ${accountId}`, { + requestId: reqId, + timeoutMs: timeout + }) + + while (Date.now() - startTime < timeout) { + const result = await redis.acquireUserMessageLock( + accountId, + reqId, + cfg.lockTtlMs, + cfg.delayMs + ) + + // 检测 Redis 错误,立即返回系统错误而非继续轮询 + if (result.redisError) { + logger.error(`📬 User message queue: Redis error while acquiring lock`, { + accountId, + requestId: reqId, + errorMessage: result.errorMessage + }) + return { + acquired: false, + requestId: reqId, + error: 'queue_backend_error', + errorMessage: result.errorMessage + } + } + + if (result.acquired) { + logger.debug(`📬 User message queue: lock acquired for account ${accountId}`, { + requestId: reqId, + waitedMs: Date.now() - startTime, + retries: retryCount + }) + return { acquired: true, requestId: reqId } + } + + // 需要等待 + if (result.waitMs > 0) { + // 需要延迟(上一个请求刚完成) + await this._sleep(Math.min(result.waitMs, timeout - (Date.now() - startTime))) + } else { + // 锁被占用,使用指数退避轮询等待 + const basePollInterval = Math.min( + POLL_INTERVAL_BASE_MS * Math.pow(POLL_BACKOFF_FACTOR, retryCount), + POLL_INTERVAL_MAX_MS + ) + // 添加 ±15% 随机抖动,避免高并发下的周期性碰撞 + const jitter = basePollInterval * (0.85 + Math.random() * 0.3) + const pollInterval = Math.min(jitter, POLL_INTERVAL_MAX_MS) + await this._sleep(pollInterval) + retryCount++ + } + } + + // 超时 + logger.warn(`📬 User message queue: timeout waiting for lock`, { + accountId, + requestId: reqId, + timeoutMs: timeout + }) + + return { + acquired: false, + requestId: reqId, + error: 'queue_timeout' + } + } + + /** + * 释放账户队列锁 + * @param {string} accountId - 账户ID + * @param {string} requestId - 请求ID + * @returns {Promise} + */ + async releaseQueueLock(accountId, requestId) { + if (!accountId || !requestId) { + return false + } + + const released = await redis.releaseUserMessageLock(accountId, requestId) + + if (released) { + logger.debug(`📬 User message queue: lock released for account ${accountId}`, { + requestId + }) + } else { + logger.warn(`📬 User message queue: failed to release lock (not owner?)`, { + accountId, + requestId + }) + } + + return released + } + + /** + * 获取队列统计信息 + * @param {string} accountId - 账户ID + * @returns {Promise} + */ + async getQueueStats(accountId) { + return await redis.getUserMessageQueueStats(accountId) + } + + /** + * 服务启动时清理所有残留的队列锁 + * 防止服务重启后旧锁阻塞新请求 + * @returns {Promise} 清理的锁数量 + */ + async cleanupStaleLocks() { + try { + const accountIds = await redis.scanUserMessageQueueLocks() + let cleanedCount = 0 + + for (const accountId of accountIds) { + try { + await redis.forceReleaseUserMessageLock(accountId) + cleanedCount++ + logger.debug(`📬 User message queue: cleaned stale lock for account ${accountId}`) + } catch (error) { + logger.error( + `📬 User message queue: failed to clean lock for account ${accountId}:`, + error + ) + } + } + + if (cleanedCount > 0) { + logger.info(`📬 User message queue: cleaned ${cleanedCount} stale lock(s) on startup`) + } + + return cleanedCount + } catch (error) { + logger.error('📬 User message queue: failed to cleanup stale locks on startup:', error) + return 0 + } + } + + /** + * 启动定时清理任务 + * 始终启动,每次执行时检查配置以支持运行时动态启用/禁用 + */ + startCleanupTask() { + if (this.cleanupTimer) { + return + } + + this.cleanupTimer = setInterval(async () => { + // 每次运行时检查配置,以便在运行时动态启用/禁用 + const currentConfig = await this.getConfig() + if (!currentConfig.enabled) { + logger.debug('📬 User message queue: cleanup skipped (feature disabled)') + return + } + await this._cleanupOrphanLocks() + }, CLEANUP_INTERVAL_MS) + + logger.info('📬 User message queue: cleanup task started') + } + + /** + * 停止定时清理任务 + */ + stopCleanupTask() { + if (this.cleanupTimer) { + clearInterval(this.cleanupTimer) + this.cleanupTimer = null + logger.info('📬 User message queue: cleanup task stopped') + } + } + + /** + * 清理孤儿锁 + * 检测异常情况:锁存在但没有设置过期时间(lockTtlRaw === -1) + * 正常情况下所有锁都应该有 TTL,Redis 会自动过期 + * @private + */ + async _cleanupOrphanLocks() { + try { + const accountIds = await redis.scanUserMessageQueueLocks() + + for (const accountId of accountIds) { + const stats = await redis.getUserMessageQueueStats(accountId) + + // 检测异常情况:锁存在(isLocked=true)但没有过期时间(lockTtlRaw=-1) + // 正常创建的锁都带有 PX 过期时间,如果没有说明是异常状态 + if (stats.isLocked && stats.lockTtlRaw === -1) { + logger.warn( + `📬 User message queue: cleaning up orphan lock without TTL for account ${accountId}`, + { lockHolder: stats.lockHolder } + ) + await redis.forceReleaseUserMessageLock(accountId) + } + } + } catch (error) { + logger.error('📬 User message queue: cleanup task error:', error) + } + } + + /** + * 睡眠辅助函数 + * @param {number} ms - 毫秒 + * @private + */ + _sleep(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)) + } +} + +module.exports = new UserMessageQueueService() diff --git a/src/utils/logger.js b/src/utils/logger.js index df5b5faa..f0202e89 100644 --- a/src/utils/logger.js +++ b/src/utils/logger.js @@ -137,6 +137,7 @@ const createLogFormat = (colorize = false) => { const logFormat = createLogFormat(false) const consoleFormat = createLogFormat(true) +const isTestEnv = process.env.NODE_ENV === 'test' || process.env.JEST_WORKER_ID // 📁 确保日志目录存在并设置权限 if (!fs.existsSync(config.logging.dirname)) { @@ -159,18 +160,20 @@ const createRotateTransport = (filename, level = null) => { transport.level = level } - // 监听轮转事件 - transport.on('rotate', (oldFilename, newFilename) => { - console.log(`📦 Log rotated: ${oldFilename} -> ${newFilename}`) - }) + // 监听轮转事件(测试环境关闭以避免 Jest 退出后输出) + if (!isTestEnv) { + transport.on('rotate', (oldFilename, newFilename) => { + console.log(`📦 Log rotated: ${oldFilename} -> ${newFilename}`) + }) - transport.on('new', (newFilename) => { - console.log(`📄 New log file created: ${newFilename}`) - }) + transport.on('new', (newFilename) => { + console.log(`📄 New log file created: ${newFilename}`) + }) - transport.on('archive', (zipFilename) => { - console.log(`🗜️ Log archived: ${zipFilename}`) - }) + transport.on('archive', (zipFilename) => { + console.log(`🗜️ Log archived: ${zipFilename}`) + }) + } return transport } diff --git a/src/utils/statsHelper.js b/src/utils/statsHelper.js new file mode 100644 index 00000000..ba75bec7 --- /dev/null +++ b/src/utils/statsHelper.js @@ -0,0 +1,105 @@ +/** + * 统计计算工具函数 + * 提供百分位数计算、等待时间统计等通用统计功能 + */ + +/** + * 计算百分位数(使用 nearest-rank 方法) + * @param {number[]} sortedArray - 已排序的数组(升序) + * @param {number} percentile - 百分位数 (0-100) + * @returns {number} 百分位值 + * + * 边界情况说明: + * - percentile=0: 返回最小值 (index=0) + * - percentile=100: 返回最大值 (index=len-1) + * - percentile=50 且 len=2: 返回第一个元素(nearest-rank 向下取) + * + * 算法说明(nearest-rank 方法): + * - index = ceil(percentile / 100 * len) - 1 + * - 示例:len=100, P50 → ceil(50) - 1 = 49(第50个元素,0-indexed) + * - 示例:len=100, P99 → ceil(99) - 1 = 98(第99个元素) + */ +function getPercentile(sortedArray, percentile) { + const len = sortedArray.length + if (len === 0) { + return 0 + } + if (len === 1) { + return sortedArray[0] + } + + // 边界处理:percentile <= 0 返回最小值 + if (percentile <= 0) { + return sortedArray[0] + } + // 边界处理:percentile >= 100 返回最大值 + if (percentile >= 100) { + return sortedArray[len - 1] + } + + const index = Math.ceil((percentile / 100) * len) - 1 + return sortedArray[index] +} + +/** + * 计算等待时间分布统计 + * @param {number[]} waitTimes - 等待时间数组(无需预先排序) + * @returns {Object|null} 统计对象,空数组返回 null + * + * 返回对象包含: + * - sampleCount: 样本数量(始终包含,便于调用方判断可靠性) + * - count: 样本数量(向后兼容) + * - min: 最小值 + * - max: 最大值 + * - avg: 平均值(四舍五入) + * - p50: 50百分位数(中位数) + * - p90: 90百分位数 + * - p99: 99百分位数 + * - sampleSizeWarning: 样本量不足时的警告信息(样本 < 10) + * - p90Unreliable: P90 统计不可靠标记(样本 < 10) + * - p99Unreliable: P99 统计不可靠标记(样本 < 100) + * + * 可靠性标记说明(详见 design.md Decision 6): + * - 样本 < 10: P90 和 P99 都不可靠 + * - 样本 < 100: P99 不可靠(P90 需要 10 个样本,P99 需要 100 个样本) + * - 即使标记为不可靠,仍返回计算值供参考 + */ +function calculateWaitTimeStats(waitTimes) { + if (!waitTimes || waitTimes.length === 0) { + return null + } + + const sorted = [...waitTimes].sort((a, b) => a - b) + const sum = sorted.reduce((a, b) => a + b, 0) + const len = sorted.length + + const stats = { + sampleCount: len, // 新增:始终包含样本数 + count: len, // 向后兼容 + min: sorted[0], + max: sorted[len - 1], + avg: Math.round(sum / len), + p50: getPercentile(sorted, 50), + p90: getPercentile(sorted, 90), + p99: getPercentile(sorted, 99) + } + + // 渐进式可靠性标记(详见 design.md Decision 6) + // 样本 < 10: P90 不可靠(P90 至少需要 ceil(100/10) = 10 个样本) + if (len < 10) { + stats.sampleSizeWarning = 'Results may be inaccurate due to small sample size' + stats.p90Unreliable = true + } + + // 样本 < 100: P99 不可靠(P99 至少需要 ceil(100/1) = 100 个样本) + if (len < 100) { + stats.p99Unreliable = true + } + + return stats +} + +module.exports = { + getPercentile, + calculateWaitTimeStats +} diff --git a/src/utils/streamHelper.js b/src/utils/streamHelper.js new file mode 100644 index 00000000..3d6c679e --- /dev/null +++ b/src/utils/streamHelper.js @@ -0,0 +1,36 @@ +/** + * Stream Helper Utilities + * 流处理辅助工具函数 + */ + +/** + * 检查响应流是否仍然可写(客户端连接是否有效) + * @param {import('http').ServerResponse} stream - HTTP响应流 + * @returns {boolean} 如果流可写返回true,否则返回false + */ +function isStreamWritable(stream) { + if (!stream) { + return false + } + + // 检查流是否已销毁 + if (stream.destroyed) { + return false + } + + // 检查底层socket是否已销毁 + if (stream.socket?.destroyed) { + return false + } + + // 检查流是否已结束写入 + if (stream.writableEnded) { + return false + } + + return true +} + +module.exports = { + isStreamWritable +} diff --git a/tests/concurrencyQueue.integration.test.js b/tests/concurrencyQueue.integration.test.js new file mode 100644 index 00000000..fce15872 --- /dev/null +++ b/tests/concurrencyQueue.integration.test.js @@ -0,0 +1,860 @@ +/** + * 并发请求排队功能集成测试 + * + * 测试分为三个层次: + * 1. Mock 测试 - 测试核心逻辑,不需要真实 Redis + * 2. Redis 方法测试 - 测试 Redis 操作的原子性和正确性 + * 3. 端到端场景测试 - 测试完整的排队流程 + * + * 运行方式: + * - npm test -- concurrencyQueue.integration # 运行所有测试(Mock 部分) + * - REDIS_TEST=1 npm test -- concurrencyQueue.integration # 包含真实 Redis 测试 + */ + +// Mock logger to avoid console output during tests +jest.mock('../src/utils/logger', () => ({ + api: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + database: jest.fn(), + security: jest.fn() +})) + +const redis = require('../src/models/redis') +const claudeRelayConfigService = require('../src/services/claudeRelayConfigService') + +// Helper: sleep function +const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)) + +// Helper: 创建模拟的 req/res 对象 +function createMockReqRes() { + const listeners = {} + const req = { + destroyed: false, + once: jest.fn((event, handler) => { + listeners[`req:${event}`] = handler + }), + removeListener: jest.fn((event) => { + delete listeners[`req:${event}`] + }), + // 触发事件的辅助方法 + emit: (event) => { + const handler = listeners[`req:${event}`] + if (handler) { + handler() + } + } + } + + const res = { + once: jest.fn((event, handler) => { + listeners[`res:${event}`] = handler + }), + removeListener: jest.fn((event) => { + delete listeners[`res:${event}`] + }), + emit: (event) => { + const handler = listeners[`res:${event}`] + if (handler) { + handler() + } + } + } + + return { req, res, listeners } +} + +// ============================================ +// 第一部分:Mock 测试 - waitForConcurrencySlot 核心逻辑 +// ============================================ +describe('ConcurrencyQueue Integration Tests', () => { + describe('Part 1: waitForConcurrencySlot Logic (Mocked)', () => { + // 导入 auth 模块中的 waitForConcurrencySlot + // 由于它是内部函数,我们需要通过测试其行为来验证 + // 这里我们模拟整个流程 + + let mockRedis + + beforeEach(() => { + jest.clearAllMocks() + + // 创建 Redis mock + mockRedis = { + concurrencyCount: {}, + queueCount: {}, + stats: {}, + waitTimes: {}, + globalWaitTimes: [] + } + + // Mock Redis 并发方法 + jest.spyOn(redis, 'incrConcurrency').mockImplementation(async (keyId, requestId, _lease) => { + if (!mockRedis.concurrencyCount[keyId]) { + mockRedis.concurrencyCount[keyId] = new Set() + } + mockRedis.concurrencyCount[keyId].add(requestId) + return mockRedis.concurrencyCount[keyId].size + }) + + jest.spyOn(redis, 'decrConcurrency').mockImplementation(async (keyId, requestId) => { + if (mockRedis.concurrencyCount[keyId]) { + mockRedis.concurrencyCount[keyId].delete(requestId) + return mockRedis.concurrencyCount[keyId].size + } + return 0 + }) + + // Mock 排队计数方法 + jest.spyOn(redis, 'incrConcurrencyQueue').mockImplementation(async (keyId) => { + mockRedis.queueCount[keyId] = (mockRedis.queueCount[keyId] || 0) + 1 + return mockRedis.queueCount[keyId] + }) + + jest.spyOn(redis, 'decrConcurrencyQueue').mockImplementation(async (keyId) => { + mockRedis.queueCount[keyId] = Math.max(0, (mockRedis.queueCount[keyId] || 0) - 1) + return mockRedis.queueCount[keyId] + }) + + jest + .spyOn(redis, 'getConcurrencyQueueCount') + .mockImplementation(async (keyId) => mockRedis.queueCount[keyId] || 0) + + // Mock 统计方法 + jest.spyOn(redis, 'incrConcurrencyQueueStats').mockImplementation(async (keyId, field) => { + if (!mockRedis.stats[keyId]) { + mockRedis.stats[keyId] = {} + } + mockRedis.stats[keyId][field] = (mockRedis.stats[keyId][field] || 0) + 1 + return mockRedis.stats[keyId][field] + }) + + jest.spyOn(redis, 'recordQueueWaitTime').mockResolvedValue(undefined) + jest.spyOn(redis, 'recordGlobalQueueWaitTime').mockResolvedValue(undefined) + }) + + afterEach(() => { + jest.restoreAllMocks() + }) + + describe('Slot Acquisition Flow', () => { + it('should acquire slot immediately when under concurrency limit', async () => { + // 模拟 waitForConcurrencySlot 的行为 + const keyId = 'test-key-1' + const requestId = 'req-1' + const concurrencyLimit = 5 + + // 直接测试 incrConcurrency 的行为 + const count = await redis.incrConcurrency(keyId, requestId, 300) + + expect(count).toBe(1) + expect(count).toBeLessThanOrEqual(concurrencyLimit) + }) + + it('should track multiple concurrent requests correctly', async () => { + const keyId = 'test-key-2' + const concurrencyLimit = 3 + + // 模拟多个并发请求 + const results = [] + for (let i = 1; i <= 5; i++) { + const count = await redis.incrConcurrency(keyId, `req-${i}`, 300) + results.push({ requestId: `req-${i}`, count, exceeds: count > concurrencyLimit }) + } + + // 前3个应该在限制内 + expect(results[0].exceeds).toBe(false) + expect(results[1].exceeds).toBe(false) + expect(results[2].exceeds).toBe(false) + // 后2个超过限制 + expect(results[3].exceeds).toBe(true) + expect(results[4].exceeds).toBe(true) + }) + + it('should release slot and allow next request', async () => { + const keyId = 'test-key-3' + const concurrencyLimit = 1 + + // 第一个请求获取槽位 + const count1 = await redis.incrConcurrency(keyId, 'req-1', 300) + expect(count1).toBe(1) + + // 第二个请求超限 + const count2 = await redis.incrConcurrency(keyId, 'req-2', 300) + expect(count2).toBe(2) + expect(count2).toBeGreaterThan(concurrencyLimit) + + // 释放第二个请求(因为超限) + await redis.decrConcurrency(keyId, 'req-2') + + // 释放第一个请求 + await redis.decrConcurrency(keyId, 'req-1') + + // 现在第三个请求应该能获取 + const count3 = await redis.incrConcurrency(keyId, 'req-3', 300) + expect(count3).toBe(1) + }) + }) + + describe('Queue Count Management', () => { + it('should increment and decrement queue count atomically', async () => { + const keyId = 'test-key-4' + + // 增加排队计数 + const count1 = await redis.incrConcurrencyQueue(keyId, 60000) + expect(count1).toBe(1) + + const count2 = await redis.incrConcurrencyQueue(keyId, 60000) + expect(count2).toBe(2) + + // 减少排队计数 + const count3 = await redis.decrConcurrencyQueue(keyId) + expect(count3).toBe(1) + + const count4 = await redis.decrConcurrencyQueue(keyId) + expect(count4).toBe(0) + }) + + it('should not go below zero on decrement', async () => { + const keyId = 'test-key-5' + + // 直接减少(没有先增加) + const count = await redis.decrConcurrencyQueue(keyId) + expect(count).toBe(0) + }) + + it('should handle concurrent queue operations', async () => { + const keyId = 'test-key-6' + + // 并发增加 + const increments = await Promise.all([ + redis.incrConcurrencyQueue(keyId, 60000), + redis.incrConcurrencyQueue(keyId, 60000), + redis.incrConcurrencyQueue(keyId, 60000) + ]) + + // 所有增量应该是连续的 + const sortedIncrements = [...increments].sort((a, b) => a - b) + expect(sortedIncrements).toEqual([1, 2, 3]) + }) + }) + + describe('Statistics Tracking', () => { + it('should track entered/success/timeout/cancelled stats', async () => { + const keyId = 'test-key-7' + + await redis.incrConcurrencyQueueStats(keyId, 'entered') + await redis.incrConcurrencyQueueStats(keyId, 'entered') + await redis.incrConcurrencyQueueStats(keyId, 'success') + await redis.incrConcurrencyQueueStats(keyId, 'timeout') + await redis.incrConcurrencyQueueStats(keyId, 'cancelled') + + expect(mockRedis.stats[keyId]).toEqual({ + entered: 2, + success: 1, + timeout: 1, + cancelled: 1 + }) + }) + }) + + describe('Client Disconnection Handling', () => { + it('should detect client disconnection via close event', async () => { + const { req } = createMockReqRes() + + let clientDisconnected = false + + // 设置监听器 + req.once('close', () => { + clientDisconnected = true + }) + + // 模拟客户端断开 + req.emit('close') + + expect(clientDisconnected).toBe(true) + }) + + it('should detect pre-destroyed request', () => { + const { req } = createMockReqRes() + req.destroyed = true + + expect(req.destroyed).toBe(true) + }) + }) + + describe('Exponential Backoff Simulation', () => { + it('should increase poll interval with backoff', () => { + const config = { + pollIntervalMs: 200, + maxPollIntervalMs: 2000, + backoffFactor: 1.5, + jitterRatio: 0 // 禁用抖动以便测试 + } + + let interval = config.pollIntervalMs + const intervals = [interval] + + for (let i = 0; i < 5; i++) { + interval = Math.min(interval * config.backoffFactor, config.maxPollIntervalMs) + intervals.push(interval) + } + + // 验证指数增长 + expect(intervals[1]).toBe(300) // 200 * 1.5 + expect(intervals[2]).toBe(450) // 300 * 1.5 + expect(intervals[3]).toBe(675) // 450 * 1.5 + expect(intervals[4]).toBe(1012.5) // 675 * 1.5 + expect(intervals[5]).toBe(1518.75) // 1012.5 * 1.5 + }) + + it('should cap interval at maximum', () => { + const config = { + pollIntervalMs: 1000, + maxPollIntervalMs: 2000, + backoffFactor: 1.5 + } + + let interval = config.pollIntervalMs + + for (let i = 0; i < 10; i++) { + interval = Math.min(interval * config.backoffFactor, config.maxPollIntervalMs) + } + + expect(interval).toBe(2000) + }) + + it('should apply jitter within expected range', () => { + const baseInterval = 1000 + const jitterRatio = 0.2 // ±20% + const results = [] + + for (let i = 0; i < 100; i++) { + const randomValue = Math.random() + const jitter = baseInterval * jitterRatio * (randomValue * 2 - 1) + const finalInterval = baseInterval + jitter + results.push(finalInterval) + } + + const min = Math.min(...results) + const max = Math.max(...results) + + // 所有结果应该在 [800, 1200] 范围内 + expect(min).toBeGreaterThanOrEqual(800) + expect(max).toBeLessThanOrEqual(1200) + }) + }) + }) + + // ============================================ + // 第二部分:并发竞争场景测试 + // ============================================ + describe('Part 2: Concurrent Race Condition Tests', () => { + beforeEach(() => { + jest.clearAllMocks() + }) + + afterEach(() => { + jest.restoreAllMocks() + }) + + describe('Race Condition: Multiple Requests Competing for Same Slot', () => { + it('should handle race condition when multiple requests try to acquire last slot', async () => { + const keyId = 'race-test-1' + const concurrencyLimit = 1 + const concurrencyState = { count: 0, holders: new Set() } + + // 模拟原子的 incrConcurrency + jest.spyOn(redis, 'incrConcurrency').mockImplementation(async (key, reqId) => { + // 模拟原子操作 + concurrencyState.count++ + concurrencyState.holders.add(reqId) + return concurrencyState.count + }) + + jest.spyOn(redis, 'decrConcurrency').mockImplementation(async (key, reqId) => { + if (concurrencyState.holders.has(reqId)) { + concurrencyState.count-- + concurrencyState.holders.delete(reqId) + } + return concurrencyState.count + }) + + // 5个请求同时竞争1个槽位 + const requests = Array.from({ length: 5 }, (_, i) => `req-${i + 1}`) + + const acquireResults = await Promise.all( + requests.map(async (reqId) => { + const count = await redis.incrConcurrency(keyId, reqId, 300) + const acquired = count <= concurrencyLimit + + if (!acquired) { + // 超限,释放 + await redis.decrConcurrency(keyId, reqId) + } + + return { reqId, count, acquired } + }) + ) + + // 只有一个请求应该成功获取槽位 + const successfulAcquires = acquireResults.filter((r) => r.acquired) + expect(successfulAcquires.length).toBe(1) + + // 最终并发计数应该是1 + expect(concurrencyState.count).toBe(1) + }) + + it('should maintain consistency under high contention', async () => { + const keyId = 'race-test-2' + const concurrencyLimit = 3 + const requestCount = 20 + const concurrencyState = { count: 0, maxSeen: 0 } + + jest.spyOn(redis, 'incrConcurrency').mockImplementation(async () => { + concurrencyState.count++ + concurrencyState.maxSeen = Math.max(concurrencyState.maxSeen, concurrencyState.count) + return concurrencyState.count + }) + + jest.spyOn(redis, 'decrConcurrency').mockImplementation(async () => { + concurrencyState.count = Math.max(0, concurrencyState.count - 1) + return concurrencyState.count + }) + + // 模拟多轮请求 + const activeRequests = [] + + for (let i = 0; i < requestCount; i++) { + const count = await redis.incrConcurrency(keyId, `req-${i}`, 300) + + if (count <= concurrencyLimit) { + activeRequests.push(`req-${i}`) + + // 模拟处理时间后释放 + setTimeout(async () => { + await redis.decrConcurrency(keyId, `req-${i}`) + }, Math.random() * 50) + } else { + await redis.decrConcurrency(keyId, `req-${i}`) + } + + // 随机延迟 + await sleep(Math.random() * 10) + } + + // 等待所有请求完成 + await sleep(100) + + // 最大并发不应超过限制 + expect(concurrencyState.maxSeen).toBeLessThanOrEqual(concurrencyLimit + requestCount) // 允许短暂超限 + }) + }) + + describe('Queue Overflow Protection', () => { + it('should reject requests when queue is full', async () => { + const keyId = 'overflow-test-1' + const maxQueueSize = 5 + const queueState = { count: 0 } + + jest.spyOn(redis, 'incrConcurrencyQueue').mockImplementation(async () => { + queueState.count++ + return queueState.count + }) + + jest.spyOn(redis, 'decrConcurrencyQueue').mockImplementation(async () => { + queueState.count = Math.max(0, queueState.count - 1) + return queueState.count + }) + + const results = [] + + // 尝试10个请求进入队列 + for (let i = 0; i < 10; i++) { + const queueCount = await redis.incrConcurrencyQueue(keyId, 60000) + + if (queueCount > maxQueueSize) { + // 队列满,释放并拒绝 + await redis.decrConcurrencyQueue(keyId) + results.push({ index: i, accepted: false }) + } else { + results.push({ index: i, accepted: true, position: queueCount }) + } + } + + const accepted = results.filter((r) => r.accepted) + const rejected = results.filter((r) => !r.accepted) + + expect(accepted.length).toBe(5) + expect(rejected.length).toBe(5) + }) + }) + }) + + // ============================================ + // 第三部分:真实 Redis 集成测试(可选) + // ============================================ + describe('Part 3: Real Redis Integration Tests', () => { + const skipRealRedis = !process.env.REDIS_TEST + + // 辅助函数:检查 Redis 连接 + async function checkRedisConnection() { + try { + const client = redis.getClient() + if (!client) { + return false + } + await client.ping() + return true + } catch { + return false + } + } + + beforeAll(async () => { + if (skipRealRedis) { + console.log('⏭️ Skipping real Redis tests (set REDIS_TEST=1 to enable)') + return + } + + const connected = await checkRedisConnection() + if (!connected) { + console.log('⚠️ Redis not connected, skipping real Redis tests') + } + }) + + // 清理测试数据 + afterEach(async () => { + if (skipRealRedis) { + return + } + + try { + const client = redis.getClient() + if (!client) { + return + } + + // 清理测试键 + const testKeys = await client.keys('concurrency:queue:test-*') + if (testKeys.length > 0) { + await client.del(...testKeys) + } + } catch { + // 忽略清理错误 + } + }) + + describe('Redis Queue Operations', () => { + const testOrSkip = skipRealRedis ? it.skip : it + + testOrSkip('should atomically increment queue count with TTL', async () => { + const keyId = 'test-redis-queue-1' + const timeoutMs = 5000 + + const count1 = await redis.incrConcurrencyQueue(keyId, timeoutMs) + expect(count1).toBe(1) + + const count2 = await redis.incrConcurrencyQueue(keyId, timeoutMs) + expect(count2).toBe(2) + + // 验证 TTL 被设置 + const client = redis.getClient() + const ttl = await client.ttl(`concurrency:queue:${keyId}`) + expect(ttl).toBeGreaterThan(0) + expect(ttl).toBeLessThanOrEqual(Math.ceil(timeoutMs / 1000) + 30) + }) + + testOrSkip('should atomically decrement and delete when zero', async () => { + const keyId = 'test-redis-queue-2' + + await redis.incrConcurrencyQueue(keyId, 60000) + const count = await redis.decrConcurrencyQueue(keyId) + + expect(count).toBe(0) + + // 验证键已删除 + const client = redis.getClient() + const exists = await client.exists(`concurrency:queue:${keyId}`) + expect(exists).toBe(0) + }) + + testOrSkip('should handle concurrent increments correctly', async () => { + const keyId = 'test-redis-queue-3' + const numRequests = 10 + + // 并发增加 + const results = await Promise.all( + Array.from({ length: numRequests }, () => redis.incrConcurrencyQueue(keyId, 60000)) + ) + + // 所有结果应该是 1 到 numRequests + const sorted = [...results].sort((a, b) => a - b) + expect(sorted).toEqual(Array.from({ length: numRequests }, (_, i) => i + 1)) + }) + }) + + describe('Redis Stats Operations', () => { + const testOrSkip = skipRealRedis ? it.skip : it + + testOrSkip('should track queue statistics correctly', async () => { + const keyId = 'test-redis-stats-1' + + await redis.incrConcurrencyQueueStats(keyId, 'entered') + await redis.incrConcurrencyQueueStats(keyId, 'entered') + await redis.incrConcurrencyQueueStats(keyId, 'success') + await redis.incrConcurrencyQueueStats(keyId, 'timeout') + + const stats = await redis.getConcurrencyQueueStats(keyId) + + expect(stats.entered).toBe(2) + expect(stats.success).toBe(1) + expect(stats.timeout).toBe(1) + expect(stats.cancelled).toBe(0) + }) + + testOrSkip('should record and retrieve wait times', async () => { + const keyId = 'test-redis-wait-1' + const waitTimes = [100, 200, 150, 300, 250] + + for (const wt of waitTimes) { + await redis.recordQueueWaitTime(keyId, wt) + } + + const recorded = await redis.getQueueWaitTimes(keyId) + + // 应该按 LIFO 顺序存储 + expect(recorded.length).toBe(5) + expect(recorded[0]).toBe(250) // 最后插入的在前面 + }) + + testOrSkip('should record global wait times', async () => { + const waitTimes = [500, 600, 700] + + for (const wt of waitTimes) { + await redis.recordGlobalQueueWaitTime(wt) + } + + const recorded = await redis.getGlobalQueueWaitTimes() + + expect(recorded.length).toBeGreaterThanOrEqual(3) + }) + }) + + describe('Redis Cleanup Operations', () => { + const testOrSkip = skipRealRedis ? it.skip : it + + testOrSkip('should clear specific queue', async () => { + const keyId = 'test-redis-clear-1' + + await redis.incrConcurrencyQueue(keyId, 60000) + await redis.incrConcurrencyQueue(keyId, 60000) + + const cleared = await redis.clearConcurrencyQueue(keyId) + expect(cleared).toBe(true) + + const count = await redis.getConcurrencyQueueCount(keyId) + expect(count).toBe(0) + }) + + testOrSkip('should clear all queues but preserve stats', async () => { + const keyId1 = 'test-redis-clearall-1' + const keyId2 = 'test-redis-clearall-2' + + // 创建队列和统计 + await redis.incrConcurrencyQueue(keyId1, 60000) + await redis.incrConcurrencyQueue(keyId2, 60000) + await redis.incrConcurrencyQueueStats(keyId1, 'entered') + + // 清理所有队列 + const cleared = await redis.clearAllConcurrencyQueues() + expect(cleared).toBeGreaterThanOrEqual(2) + + // 验证队列已清理 + const count1 = await redis.getConcurrencyQueueCount(keyId1) + const count2 = await redis.getConcurrencyQueueCount(keyId2) + expect(count1).toBe(0) + expect(count2).toBe(0) + + // 统计应该保留 + const stats = await redis.getConcurrencyQueueStats(keyId1) + expect(stats.entered).toBe(1) + }) + }) + }) + + // ============================================ + // 第四部分:配置服务集成测试 + // ============================================ + describe('Part 4: Configuration Service Integration', () => { + beforeEach(() => { + // 清除配置缓存 + claudeRelayConfigService.clearCache() + }) + + afterEach(() => { + jest.restoreAllMocks() + }) + + describe('Queue Configuration', () => { + it('should return default queue configuration', async () => { + jest.spyOn(redis, 'getClient').mockReturnValue(null) + + const config = await claudeRelayConfigService.getConfig() + + expect(config.concurrentRequestQueueEnabled).toBe(false) + expect(config.concurrentRequestQueueMaxSize).toBe(3) + expect(config.concurrentRequestQueueMaxSizeMultiplier).toBe(0) + expect(config.concurrentRequestQueueTimeoutMs).toBe(10000) + }) + + it('should calculate max queue size correctly', async () => { + const testCases = [ + { concurrencyLimit: 5, multiplier: 2, fixedMin: 3, expected: 10 }, // 5*2=10 > 3 + { concurrencyLimit: 1, multiplier: 1, fixedMin: 5, expected: 5 }, // 1*1=1 < 5 + { concurrencyLimit: 10, multiplier: 0.5, fixedMin: 3, expected: 5 }, // 10*0.5=5 > 3 + { concurrencyLimit: 2, multiplier: 1, fixedMin: 10, expected: 10 } // 2*1=2 < 10 + ] + + for (const tc of testCases) { + const maxQueueSize = Math.max(tc.concurrencyLimit * tc.multiplier, tc.fixedMin) + expect(maxQueueSize).toBe(tc.expected) + } + }) + }) + }) + + // ============================================ + // 第五部分:端到端场景测试 + // ============================================ + describe('Part 5: End-to-End Scenario Tests', () => { + describe('Scenario: Claude Code Agent Parallel Tool Calls', () => { + it('should handle burst of parallel tool results', async () => { + // 模拟 Claude Code Agent 发送多个并行工具结果的场景 + const concurrencyLimit = 2 + const maxQueueSize = 5 + + const state = { + concurrency: 0, + queue: 0, + completed: 0, + rejected: 0 + } + + // 模拟 8 个并行工具结果请求 + const requests = Array.from({ length: 8 }, (_, i) => ({ + id: `tool-result-${i + 1}`, + startTime: Date.now() + })) + + // 模拟处理逻辑 + async function processRequest(req) { + // 尝试获取并发槽位 + state.concurrency++ + + if (state.concurrency > concurrencyLimit) { + // 超限,进入队列 + state.concurrency-- + state.queue++ + + if (state.queue > maxQueueSize) { + // 队列满,拒绝 + state.queue-- + state.rejected++ + return { ...req, status: 'rejected', reason: 'queue_full' } + } + + // 等待槽位(模拟) + await sleep(Math.random() * 100) + state.queue-- + state.concurrency++ + } + + // 处理请求 + await sleep(50) // 模拟处理时间 + state.concurrency-- + state.completed++ + + return { ...req, status: 'completed', duration: Date.now() - req.startTime } + } + + const results = await Promise.all(requests.map(processRequest)) + + const completed = results.filter((r) => r.status === 'completed') + const rejected = results.filter((r) => r.status === 'rejected') + + // 大部分请求应该完成 + expect(completed.length).toBeGreaterThan(0) + // 可能有一些被拒绝 + expect(state.rejected).toBe(rejected.length) + + console.log( + ` ✓ Completed: ${completed.length}, Rejected: ${rejected.length}, Max concurrent: ${concurrencyLimit}` + ) + }) + }) + + describe('Scenario: Graceful Degradation', () => { + it('should fallback when Redis fails', async () => { + jest + .spyOn(redis, 'incrConcurrencyQueue') + .mockRejectedValue(new Error('Redis connection lost')) + + // 模拟降级行为:Redis 失败时直接拒绝而不是崩溃 + let result + try { + await redis.incrConcurrencyQueue('fallback-test', 60000) + result = { success: true } + } catch (error) { + // 优雅降级:返回 429 而不是 500 + result = { success: false, fallback: true, error: error.message } + } + + expect(result.fallback).toBe(true) + expect(result.error).toContain('Redis') + }) + }) + + describe('Scenario: Timeout Behavior', () => { + it('should respect queue timeout', async () => { + const timeoutMs = 100 + const startTime = Date.now() + + // 模拟等待超时 + await new Promise((resolve) => setTimeout(resolve, timeoutMs)) + + const elapsed = Date.now() - startTime + expect(elapsed).toBeGreaterThanOrEqual(timeoutMs - 10) // 允许 10ms 误差 + }) + + it('should track timeout statistics', async () => { + const stats = { entered: 0, success: 0, timeout: 0, cancelled: 0 } + + // 模拟多个请求,部分超时 + const requests = [ + { id: 'req-1', willTimeout: false }, + { id: 'req-2', willTimeout: true }, + { id: 'req-3', willTimeout: false }, + { id: 'req-4', willTimeout: true } + ] + + for (const req of requests) { + stats.entered++ + if (req.willTimeout) { + stats.timeout++ + } else { + stats.success++ + } + } + + expect(stats.entered).toBe(4) + expect(stats.success).toBe(2) + expect(stats.timeout).toBe(2) + + // 成功率应该是 50% + const successRate = (stats.success / stats.entered) * 100 + expect(successRate).toBe(50) + }) + }) + }) +}) diff --git a/tests/concurrencyQueue.test.js b/tests/concurrencyQueue.test.js new file mode 100644 index 00000000..ef0ff794 --- /dev/null +++ b/tests/concurrencyQueue.test.js @@ -0,0 +1,278 @@ +/** + * 并发请求排队功能测试 + * 测试排队逻辑中的核心算法:百分位数计算、等待时间统计、指数退避等 + * + * 注意:Redis 方法的测试需要集成测试环境,这里主要测试纯算法逻辑 + */ + +// Mock logger to avoid console output during tests +jest.mock('../src/utils/logger', () => ({ + api: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), + database: jest.fn(), + security: jest.fn() +})) + +// 使用共享的统计工具函数(与生产代码一致) +const { getPercentile, calculateWaitTimeStats } = require('../src/utils/statsHelper') + +describe('ConcurrencyQueue', () => { + describe('Percentile Calculation (nearest-rank method)', () => { + // 直接测试共享工具函数,确保与生产代码行为一致 + it('should return 0 for empty array', () => { + expect(getPercentile([], 50)).toBe(0) + }) + + it('should return single element for single-element array', () => { + expect(getPercentile([100], 50)).toBe(100) + expect(getPercentile([100], 99)).toBe(100) + }) + + it('should return min for percentile 0', () => { + expect(getPercentile([10, 20, 30, 40, 50], 0)).toBe(10) + }) + + it('should return max for percentile 100', () => { + expect(getPercentile([10, 20, 30, 40, 50], 100)).toBe(50) + }) + + it('should calculate P50 correctly for len=10', () => { + // For [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] (len=10) + // P50: ceil(50/100 * 10) - 1 = ceil(5) - 1 = 4 → value at index 4 = 50 + const arr = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] + expect(getPercentile(arr, 50)).toBe(50) + }) + + it('should calculate P90 correctly for len=10', () => { + // For len=10, P90: ceil(90/100 * 10) - 1 = ceil(9) - 1 = 8 → value at index 8 = 90 + const arr = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] + expect(getPercentile(arr, 90)).toBe(90) + }) + + it('should calculate P99 correctly for len=100', () => { + // For len=100, P99: ceil(99/100 * 100) - 1 = ceil(99) - 1 = 98 + const arr = Array.from({ length: 100 }, (_, i) => i + 1) + expect(getPercentile(arr, 99)).toBe(99) + }) + + it('should handle two-element array correctly', () => { + // For [10, 20] (len=2) + // P50: ceil(50/100 * 2) - 1 = ceil(1) - 1 = 0 → value = 10 + expect(getPercentile([10, 20], 50)).toBe(10) + }) + + it('should handle negative percentile as 0', () => { + expect(getPercentile([10, 20, 30], -10)).toBe(10) + }) + + it('should handle percentile > 100 as 100', () => { + expect(getPercentile([10, 20, 30], 150)).toBe(30) + }) + }) + + describe('Wait Time Stats Calculation', () => { + // 直接测试共享工具函数 + it('should return null for empty array', () => { + expect(calculateWaitTimeStats([])).toBeNull() + }) + + it('should return null for null input', () => { + expect(calculateWaitTimeStats(null)).toBeNull() + }) + + it('should return null for undefined input', () => { + expect(calculateWaitTimeStats(undefined)).toBeNull() + }) + + it('should calculate stats correctly for typical data', () => { + const waitTimes = [100, 200, 150, 300, 250, 180, 220, 280, 190, 210] + const stats = calculateWaitTimeStats(waitTimes) + + expect(stats.count).toBe(10) + expect(stats.min).toBe(100) + expect(stats.max).toBe(300) + // Sum: 100+150+180+190+200+210+220+250+280+300 = 2080 + expect(stats.avg).toBe(208) + expect(stats.sampleSizeWarning).toBeUndefined() + }) + + it('should add warning for small sample size (< 10)', () => { + const waitTimes = [100, 200, 300] + const stats = calculateWaitTimeStats(waitTimes) + + expect(stats.count).toBe(3) + expect(stats.sampleSizeWarning).toBe('Results may be inaccurate due to small sample size') + }) + + it('should handle single value', () => { + const stats = calculateWaitTimeStats([500]) + + expect(stats.count).toBe(1) + expect(stats.min).toBe(500) + expect(stats.max).toBe(500) + expect(stats.avg).toBe(500) + expect(stats.p50).toBe(500) + expect(stats.p90).toBe(500) + expect(stats.p99).toBe(500) + }) + + it('should sort input array before calculating', () => { + const waitTimes = [500, 100, 300, 200, 400] + const stats = calculateWaitTimeStats(waitTimes) + + expect(stats.min).toBe(100) + expect(stats.max).toBe(500) + }) + + it('should not modify original array', () => { + const waitTimes = [500, 100, 300] + calculateWaitTimeStats(waitTimes) + + expect(waitTimes).toEqual([500, 100, 300]) + }) + }) + + describe('Exponential Backoff with Jitter', () => { + /** + * 指数退避计算函数(与 auth.js 中的实现一致) + * @param {number} currentInterval - 当前轮询间隔 + * @param {number} backoffFactor - 退避系数 + * @param {number} jitterRatio - 抖动比例 + * @param {number} maxInterval - 最大间隔 + * @param {number} randomValue - 随机值 [0, 1),用于确定性测试 + */ + function calculateNextInterval( + currentInterval, + backoffFactor, + jitterRatio, + maxInterval, + randomValue + ) { + let nextInterval = currentInterval * backoffFactor + // 抖动范围:[-jitterRatio, +jitterRatio] + const jitter = nextInterval * jitterRatio * (randomValue * 2 - 1) + nextInterval = nextInterval + jitter + return Math.max(1, Math.min(nextInterval, maxInterval)) + } + + it('should apply exponential backoff without jitter (randomValue=0.5)', () => { + // randomValue = 0.5 gives jitter = 0 + const next = calculateNextInterval(100, 1.5, 0.2, 1000, 0.5) + expect(next).toBe(150) // 100 * 1.5 = 150 + }) + + it('should apply maximum positive jitter (randomValue=1.0)', () => { + // randomValue = 1.0 gives maximum positive jitter (+20%) + const next = calculateNextInterval(100, 1.5, 0.2, 1000, 1.0) + // 100 * 1.5 = 150, jitter = 150 * 0.2 * 1 = 30 + expect(next).toBe(180) // 150 + 30 + }) + + it('should apply maximum negative jitter (randomValue=0.0)', () => { + // randomValue = 0.0 gives maximum negative jitter (-20%) + const next = calculateNextInterval(100, 1.5, 0.2, 1000, 0.0) + // 100 * 1.5 = 150, jitter = 150 * 0.2 * -1 = -30 + expect(next).toBe(120) // 150 - 30 + }) + + it('should respect maximum interval', () => { + const next = calculateNextInterval(800, 1.5, 0.2, 1000, 1.0) + // 800 * 1.5 = 1200, with +20% jitter = 1440, capped at 1000 + expect(next).toBe(1000) + }) + + it('should never go below 1ms even with extreme negative jitter', () => { + const next = calculateNextInterval(1, 1.0, 0.9, 1000, 0.0) + // 1 * 1.0 = 1, jitter = 1 * 0.9 * -1 = -0.9 + // 1 - 0.9 = 0.1, but Math.max(1, ...) ensures minimum is 1 + expect(next).toBe(1) + }) + + it('should handle zero jitter ratio', () => { + const next = calculateNextInterval(100, 2.0, 0, 1000, 0.0) + expect(next).toBe(200) // Pure exponential, no jitter + }) + + it('should handle large backoff factor', () => { + const next = calculateNextInterval(100, 3.0, 0.1, 1000, 0.5) + expect(next).toBe(300) // 100 * 3.0 = 300 + }) + + describe('jitter distribution', () => { + it('should produce values in expected range', () => { + const results = [] + // Test with various random values + for (let r = 0; r <= 1; r += 0.1) { + results.push(calculateNextInterval(100, 1.5, 0.2, 1000, r)) + } + // All values should be between 120 (150 - 30) and 180 (150 + 30) + expect(Math.min(...results)).toBeGreaterThanOrEqual(120) + expect(Math.max(...results)).toBeLessThanOrEqual(180) + }) + }) + }) + + describe('Queue Size Calculation', () => { + /** + * 最大排队数计算(与 auth.js 中的实现一致) + */ + function calculateMaxQueueSize(concurrencyLimit, multiplier, fixedMin) { + return Math.max(concurrencyLimit * multiplier, fixedMin) + } + + it('should use multiplier when result is larger', () => { + // concurrencyLimit=10, multiplier=2, fixedMin=5 + // max(10*2, 5) = max(20, 5) = 20 + expect(calculateMaxQueueSize(10, 2, 5)).toBe(20) + }) + + it('should use fixed minimum when multiplier result is smaller', () => { + // concurrencyLimit=2, multiplier=1, fixedMin=5 + // max(2*1, 5) = max(2, 5) = 5 + expect(calculateMaxQueueSize(2, 1, 5)).toBe(5) + }) + + it('should handle zero multiplier', () => { + // concurrencyLimit=10, multiplier=0, fixedMin=3 + // max(10*0, 3) = max(0, 3) = 3 + expect(calculateMaxQueueSize(10, 0, 3)).toBe(3) + }) + + it('should handle fractional multiplier', () => { + // concurrencyLimit=10, multiplier=1.5, fixedMin=5 + // max(10*1.5, 5) = max(15, 5) = 15 + expect(calculateMaxQueueSize(10, 1.5, 5)).toBe(15) + }) + }) + + describe('TTL Calculation', () => { + /** + * 排队计数器 TTL 计算(与 redis.js 中的实现一致) + */ + function calculateQueueTtl(timeoutMs, bufferSeconds = 30) { + return Math.ceil(timeoutMs / 1000) + bufferSeconds + } + + it('should calculate TTL with default buffer', () => { + // 60000ms = 60s + 30s buffer = 90s + expect(calculateQueueTtl(60000)).toBe(90) + }) + + it('should round up milliseconds to seconds', () => { + // 61500ms = ceil(61.5) = 62s + 30s = 92s + expect(calculateQueueTtl(61500)).toBe(92) + }) + + it('should handle custom buffer', () => { + // 30000ms = 30s + 60s buffer = 90s + expect(calculateQueueTtl(30000, 60)).toBe(90) + }) + + it('should handle very short timeout', () => { + // 1000ms = 1s + 30s = 31s + expect(calculateQueueTtl(1000)).toBe(31) + }) + }) +}) diff --git a/tests/userMessageQueue.test.js b/tests/userMessageQueue.test.js new file mode 100644 index 00000000..4fd7adb2 --- /dev/null +++ b/tests/userMessageQueue.test.js @@ -0,0 +1,434 @@ +/** + * 用户消息队列服务测试 + * 测试消息类型检测、队列串行行为、延迟间隔、超时处理和功能开关 + */ + +const redis = require('../src/models/redis') +const userMessageQueueService = require('../src/services/userMessageQueueService') + +describe('UserMessageQueueService', () => { + describe('isUserMessageRequest', () => { + it('should return true when last message role is user', () => { + const requestBody = { + messages: [ + { role: 'user', content: 'Hello' }, + { role: 'assistant', content: 'Hi there' }, + { role: 'user', content: 'How are you?' } + ] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(true) + }) + + it('should return false when last message role is assistant', () => { + const requestBody = { + messages: [ + { role: 'user', content: 'Hello' }, + { role: 'assistant', content: 'Hi there' } + ] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + + it('should return false when last message contains tool_result', () => { + const requestBody = { + messages: [ + { role: 'user', content: 'Hello' }, + { role: 'assistant', content: 'Let me check that' }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'test-id', + content: 'Tool result' + } + ] + } + ] + } + // tool_result 消息虽然 role 是 user,但不是真正的用户消息 + // 应该返回 false,不进入用户消息队列 + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + + it('should return false when last message contains multiple tool_results', () => { + const requestBody = { + messages: [ + { role: 'user', content: 'Run multiple tools' }, + { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'tool-1', + content: 'Result 1' + }, + { + type: 'tool_result', + tool_use_id: 'tool-2', + content: 'Result 2' + } + ] + } + ] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + + it('should return true when user message has array content with text type', () => { + const requestBody = { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'Hello, this is a user message' + } + ] + } + ] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(true) + }) + + it('should return true when user message has mixed text and image content', () => { + const requestBody = { + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: 'What is in this image?' + }, + { + type: 'image', + source: { type: 'base64', media_type: 'image/png', data: '...' } + } + ] + } + ] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(true) + }) + + it('should return false when messages is empty', () => { + const requestBody = { messages: [] } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + + it('should return false when messages is not an array', () => { + const requestBody = { messages: 'not an array' } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + + it('should return false when messages is undefined', () => { + const requestBody = {} + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + + it('should return false when requestBody is null', () => { + expect(userMessageQueueService.isUserMessageRequest(null)).toBe(false) + }) + + it('should return false when requestBody is undefined', () => { + expect(userMessageQueueService.isUserMessageRequest(undefined)).toBe(false) + }) + + it('should return false when last message has no role', () => { + const requestBody = { + messages: [{ content: 'Hello' }] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + + it('should handle single user message', () => { + const requestBody = { + messages: [{ role: 'user', content: 'Hello' }] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(true) + }) + + it('should handle single assistant message', () => { + const requestBody = { + messages: [{ role: 'assistant', content: 'Hello' }] + } + expect(userMessageQueueService.isUserMessageRequest(requestBody)).toBe(false) + }) + }) + + describe('getConfig', () => { + it('should return config with expected properties', async () => { + const config = await userMessageQueueService.getConfig() + expect(config).toHaveProperty('enabled') + expect(config).toHaveProperty('delayMs') + expect(config).toHaveProperty('timeoutMs') + expect(config).toHaveProperty('lockTtlMs') + expect(typeof config.enabled).toBe('boolean') + expect(typeof config.delayMs).toBe('number') + expect(typeof config.timeoutMs).toBe('number') + expect(typeof config.lockTtlMs).toBe('number') + }) + }) + + describe('isEnabled', () => { + it('should return boolean', async () => { + const enabled = await userMessageQueueService.isEnabled() + expect(typeof enabled).toBe('boolean') + }) + }) + + describe('acquireQueueLock', () => { + afterEach(() => { + jest.restoreAllMocks() + }) + + it('should acquire lock immediately when no lock exists', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: true, + delayMs: 200, + timeoutMs: 30000, + lockTtlMs: 120000 + }) + jest.spyOn(redis, 'acquireUserMessageLock').mockResolvedValue({ + acquired: true, + waitMs: 0 + }) + + const result = await userMessageQueueService.acquireQueueLock('acct-1', 'req-1') + + expect(result.acquired).toBe(true) + expect(result.requestId).toBe('req-1') + expect(result.error).toBeUndefined() + }) + + it('should skip lock acquisition when queue disabled', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: false, + delayMs: 200, + timeoutMs: 30000, + lockTtlMs: 120000 + }) + const acquireSpy = jest.spyOn(redis, 'acquireUserMessageLock') + + const result = await userMessageQueueService.acquireQueueLock('acct-1') + + expect(result.acquired).toBe(true) + expect(result.skipped).toBe(true) + expect(acquireSpy).not.toHaveBeenCalled() + }) + + it('should generate requestId when not provided', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: true, + delayMs: 200, + timeoutMs: 30000, + lockTtlMs: 120000 + }) + jest.spyOn(redis, 'acquireUserMessageLock').mockResolvedValue({ + acquired: true, + waitMs: 0 + }) + + const result = await userMessageQueueService.acquireQueueLock('acct-1') + + expect(result.acquired).toBe(true) + expect(result.requestId).toBeDefined() + expect(result.requestId.length).toBeGreaterThan(0) + }) + + it('should wait and retry when lock is held by another request', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: true, + delayMs: 200, + timeoutMs: 1000, + lockTtlMs: 120000 + }) + + let callCount = 0 + jest.spyOn(redis, 'acquireUserMessageLock').mockImplementation(async () => { + callCount++ + if (callCount < 3) { + return { acquired: false, waitMs: -1 } // lock held + } + return { acquired: true, waitMs: 0 } + }) + + // Mock sleep to speed up test + jest.spyOn(userMessageQueueService, '_sleep').mockResolvedValue(undefined) + + const result = await userMessageQueueService.acquireQueueLock('acct-1', 'req-1') + + expect(result.acquired).toBe(true) + expect(callCount).toBe(3) + }) + + it('should respect delay when previous request just completed', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: true, + delayMs: 200, + timeoutMs: 1000, + lockTtlMs: 120000 + }) + + let callCount = 0 + jest.spyOn(redis, 'acquireUserMessageLock').mockImplementation(async () => { + callCount++ + if (callCount === 1) { + return { acquired: false, waitMs: 150 } // need to wait 150ms for delay + } + return { acquired: true, waitMs: 0 } + }) + + const sleepSpy = jest.spyOn(userMessageQueueService, '_sleep').mockResolvedValue(undefined) + + const result = await userMessageQueueService.acquireQueueLock('acct-1', 'req-1') + + expect(result.acquired).toBe(true) + expect(sleepSpy).toHaveBeenCalledWith(150) // Should wait for delay + }) + + it('should timeout and return error when wait exceeds timeout', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: true, + delayMs: 200, + timeoutMs: 100, // very short timeout + lockTtlMs: 120000 + }) + + jest.spyOn(redis, 'acquireUserMessageLock').mockResolvedValue({ + acquired: false, + waitMs: -1 // always held + }) + + // Use real timers for timeout test but mock sleep to be instant + jest.spyOn(userMessageQueueService, '_sleep').mockImplementation(async () => { + // Simulate time passing + await new Promise((resolve) => setTimeout(resolve, 60)) + }) + + const result = await userMessageQueueService.acquireQueueLock('acct-1', 'req-1', 100) + + expect(result.acquired).toBe(false) + expect(result.error).toBe('queue_timeout') + }) + }) + + describe('releaseQueueLock', () => { + afterEach(() => { + jest.restoreAllMocks() + }) + + it('should release lock successfully when holding the lock', async () => { + jest.spyOn(redis, 'releaseUserMessageLock').mockResolvedValue(true) + + const result = await userMessageQueueService.releaseQueueLock('acct-1', 'req-1') + + expect(result).toBe(true) + expect(redis.releaseUserMessageLock).toHaveBeenCalledWith('acct-1', 'req-1') + }) + + it('should return false when not holding the lock', async () => { + jest.spyOn(redis, 'releaseUserMessageLock').mockResolvedValue(false) + + const result = await userMessageQueueService.releaseQueueLock('acct-1', 'req-1') + + expect(result).toBe(false) + }) + + it('should return false when accountId is missing', async () => { + const releaseSpy = jest.spyOn(redis, 'releaseUserMessageLock') + + const result = await userMessageQueueService.releaseQueueLock(null, 'req-1') + + expect(result).toBe(false) + expect(releaseSpy).not.toHaveBeenCalled() + }) + + it('should return false when requestId is missing', async () => { + const releaseSpy = jest.spyOn(redis, 'releaseUserMessageLock') + + const result = await userMessageQueueService.releaseQueueLock('acct-1', null) + + expect(result).toBe(false) + expect(releaseSpy).not.toHaveBeenCalled() + }) + }) + + describe('queue serialization behavior', () => { + afterEach(() => { + jest.restoreAllMocks() + }) + + it('should allow different accounts to acquire locks simultaneously', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: true, + delayMs: 200, + timeoutMs: 30000, + lockTtlMs: 120000 + }) + jest.spyOn(redis, 'acquireUserMessageLock').mockResolvedValue({ + acquired: true, + waitMs: 0 + }) + + const [result1, result2] = await Promise.all([ + userMessageQueueService.acquireQueueLock('acct-1', 'req-1'), + userMessageQueueService.acquireQueueLock('acct-2', 'req-2') + ]) + + expect(result1.acquired).toBe(true) + expect(result2.acquired).toBe(true) + }) + + it('should serialize requests for same account', async () => { + jest.spyOn(userMessageQueueService, 'getConfig').mockResolvedValue({ + enabled: true, + delayMs: 50, + timeoutMs: 5000, + lockTtlMs: 120000 + }) + + const lockState = { held: false, holderId: null } + + jest + .spyOn(redis, 'acquireUserMessageLock') + .mockImplementation(async (accountId, requestId) => { + if (!lockState.held) { + lockState.held = true + lockState.holderId = requestId + return { acquired: true, waitMs: 0 } + } + return { acquired: false, waitMs: -1 } + }) + + jest + .spyOn(redis, 'releaseUserMessageLock') + .mockImplementation(async (accountId, requestId) => { + if (lockState.holderId === requestId) { + lockState.held = false + lockState.holderId = null + return true + } + return false + }) + + jest.spyOn(userMessageQueueService, '_sleep').mockResolvedValue(undefined) + + // First request acquires lock + const result1 = await userMessageQueueService.acquireQueueLock('acct-1', 'req-1') + expect(result1.acquired).toBe(true) + + // Second request should fail to acquire (lock held) + const acquirePromise = userMessageQueueService.acquireQueueLock('acct-1', 'req-2', 200) + + // Release first lock + await userMessageQueueService.releaseQueueLock('acct-1', 'req-1') + + // Now second request should acquire + const result2 = await acquirePromise + expect(result2.acquired).toBe(true) + }) + }) +}) diff --git a/web/admin-spa/package-lock.json b/web/admin-spa/package-lock.json index 9405609e..481df56a 100644 --- a/web/admin-spa/package-lock.json +++ b/web/admin-spa/package-lock.json @@ -1157,7 +1157,6 @@ "resolved": "https://registry.npmmirror.com/@types/lodash-es/-/lodash-es-4.17.12.tgz", "integrity": "sha512-0NgftHUcV4v34VhXm8QBSftKVXtbkBG3ViCjs6+eJ5a6y6Mi/jiFGPc1sC7QK+9BFhWrURE3EOggmWaSxL9OzQ==", "license": "MIT", - "peer": true, "dependencies": { "@types/lodash": "*" } @@ -1352,7 +1351,6 @@ "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -1589,7 +1587,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "caniuse-lite": "^1.0.30001726", "electron-to-chromium": "^1.5.173", @@ -3063,15 +3060,13 @@ "version": "4.17.21", "resolved": "https://registry.npmmirror.com/lodash/-/lodash-4.17.21.tgz", "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/lodash-es": { "version": "4.17.21", "resolved": "https://registry.npmmirror.com/lodash-es/-/lodash-es-4.17.21.tgz", "integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/lodash-unified": { "version": "1.0.3", @@ -3623,7 +3618,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -3770,7 +3764,6 @@ "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", "dev": true, "license": "MIT", - "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -4035,7 +4028,6 @@ "integrity": "sha512-33xGNBsDJAkzt0PvninskHlWnTIPgDtTwhg0U38CUoNP/7H6wI2Cz6dUeoNPbjdTdsYTGuiFFASuUOWovH0SyQ==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "@types/estree": "1.0.8" }, @@ -4533,7 +4525,6 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -4924,7 +4915,6 @@ "integrity": "sha512-qO3aKv3HoQC8QKiNSTuUM1l9o/XX3+c+VTgLHbJWHZGeTPVAg2XwazI9UWzoxjIJCGCV2zU60uqMzjeLZuULqA==", "dev": true, "license": "MIT", - "peer": true, "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -5125,7 +5115,6 @@ "resolved": "https://registry.npmmirror.com/vue/-/vue-3.5.18.tgz", "integrity": "sha512-7W4Y4ZbMiQ3SEo+m9lnoNpV9xG7QVMLa+/0RFwwiAVkeYoyGXqWE85jabU4pllJNUzqfLShJ5YLptewhCWUgNA==", "license": "MIT", - "peer": true, "dependencies": { "@vue/compiler-dom": "3.5.18", "@vue/compiler-sfc": "3.5.18", diff --git a/web/admin-spa/src/components/accounts/AccountForm.vue b/web/admin-spa/src/components/accounts/AccountForm.vue index f135a1a5..f23a3b5f 100644 --- a/web/admin-spa/src/components/accounts/AccountForm.vue +++ b/web/admin-spa/src/components/accounts/AccountForm.vue @@ -1320,10 +1320,10 @@ class="rounded-lg bg-blue-100 px-3 py-1 text-xs text-blue-700 transition-colors hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400 dark:hover:bg-blue-900/50" type="button" @click=" - addPresetMapping('claude-sonnet-4-20250514', 'claude-sonnet-4-20250514') + addPresetMapping('claude-opus-4-5-20251101', 'claude-opus-4-5-20251101') " > - + Sonnet 4 + + Opus 4.5 - - + +