From 66fe3cf74a2c4e902f0bd5c83dbd57f5d25e25d2 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 10 Oct 2025 17:16:10 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BC=98=E5=8C=96count=5Ftokens?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E4=B8=8D=E5=8F=97=E5=B9=B6=E5=8F=91=E8=B7=9F?= =?UTF-8?q?=E5=AE=A2=E6=88=B7=E7=AB=AF=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/middleware/auth.js | 36 +++++++- src/routes/droidRoutes.js | 43 ++++++++++ src/services/droidRelayService.js | 131 +++++++++++++++++++----------- 3 files changed, 161 insertions(+), 49 deletions(-) diff --git a/src/middleware/auth.js b/src/middleware/auth.js index e8a67a43..b89586a4 100644 --- a/src/middleware/auth.js +++ b/src/middleware/auth.js @@ -7,6 +7,37 @@ const redis = require('../models/redis') // const { RateLimiterRedis } = require('rate-limiter-flexible') // 暂时未使用 const ClientValidator = require('../validators/clientValidator') +const TOKEN_COUNT_PATHS = new Set([ + '/v1/messages/count_tokens', + '/api/v1/messages/count_tokens', + '/claude/v1/messages/count_tokens', + '/droid/claude/v1/messages/count_tokens' +]) + +function normalizeRequestPath(value) { + if (!value) { + return '/' + } + const lower = value.split('?')[0].toLowerCase() + const collapsed = lower.replace(/\/{2,}/g, '/') + if (collapsed.length > 1 && collapsed.endsWith('/')) { + return collapsed.slice(0, -1) + } + return collapsed || '/' +} + +function isTokenCountRequest(req) { + const combined = normalizeRequestPath(`${req.baseUrl || ''}${req.path || ''}`) + if (TOKEN_COUNT_PATHS.has(combined)) { + return true + } + const original = normalizeRequestPath(req.originalUrl || '') + if (TOKEN_COUNT_PATHS.has(original)) { + return true + } + return false +} + // 🔑 API Key验证中间件(优化版) const authenticateApiKey = async (req, res, next) => { const startTime = Date.now() @@ -49,8 +80,11 @@ const authenticateApiKey = async (req, res, next) => { }) } + const skipKeyRestrictions = isTokenCountRequest(req) + // 🔒 检查客户端限制(使用新的验证器) if ( + !skipKeyRestrictions && validation.keyData.enableClientRestriction && validation.keyData.allowedClients?.length > 0 ) { @@ -81,7 +115,7 @@ const authenticateApiKey = async (req, res, next) => { // 检查并发限制 const concurrencyLimit = validation.keyData.concurrencyLimit || 0 - if (concurrencyLimit > 0) { + if (!skipKeyRestrictions && concurrencyLimit > 0) { const concurrencyConfig = config.concurrency || {} const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30) const rawRenewInterval = diff --git a/src/routes/droidRoutes.js b/src/routes/droidRoutes.js index ea96d80d..d99cc071 100644 --- a/src/routes/droidRoutes.js +++ b/src/routes/droidRoutes.js @@ -60,6 +60,49 @@ router.post('/claude/v1/messages', authenticateApiKey, async (req, res) => { } }) +router.post('/claude/v1/messages/count_tokens', authenticateApiKey, async (req, res) => { + try { + const requestBody = { ...req.body } + if ('stream' in requestBody) { + delete requestBody.stream + } + const sessionHash = sessionHelper.generateSessionHash(requestBody) + + if (!hasDroidPermission(req.apiKey)) { + logger.security( + `🚫 API Key ${req.apiKey?.id || 'unknown'} 缺少 Droid 权限,拒绝访问 ${req.originalUrl}` + ) + return res.status(403).json({ + error: 'permission_denied', + message: '此 API Key 未启用 Droid 权限' + }) + } + + const result = await droidRelayService.relayRequest( + requestBody, + req.apiKey, + req, + res, + req.headers, + { + endpointType: 'anthropic', + sessionHash, + customPath: '/a/v1/messages/count_tokens', + skipUsageRecord: true, + disableStreaming: true + } + ) + + res.status(result.statusCode).set(result.headers).send(result.body) + } catch (error) { + logger.error('Droid Claude count_tokens relay error:', error) + res.status(500).json({ + error: 'internal_server_error', + message: error.message + }) + } +}) + // OpenAI 端点 - /v1/responses router.post('/openai/v1/responses', authenticateApiKey, async (req, res) => { try { diff --git a/src/services/droidRelayService.js b/src/services/droidRelayService.js index 056c85b1..e274c625 100644 --- a/src/services/droidRelayService.js +++ b/src/services/droidRelayService.js @@ -146,7 +146,13 @@ class DroidRelayService { clientHeaders, options = {} ) { - const { endpointType = 'anthropic', sessionHash = null } = options + const { + endpointType = 'anthropic', + sessionHash = null, + customPath = null, + skipUsageRecord = false, + disableStreaming = false + } = options const keyInfo = apiKeyData || {} const normalizedEndpoint = this._normalizeEndpointType(endpointType) @@ -179,8 +185,12 @@ class DroidRelayService { } // 获取 Factory.ai API URL - const endpoint = this.endpoints[normalizedEndpoint] - const apiUrl = `${this.factoryApiBaseUrl}${endpoint}` + let endpointPath = this.endpoints[normalizedEndpoint] + if (typeof customPath === 'string' && customPath.trim()) { + endpointPath = customPath.startsWith('/') ? customPath : `/${customPath}` + } + + const apiUrl = `${this.factoryApiBaseUrl}${endpointPath}` logger.info(`🌐 Forwarding to Factory.ai: ${apiUrl}`) @@ -207,10 +217,12 @@ class DroidRelayService { } // 处理请求体(注入 system prompt 等) - const processedBody = this._processRequestBody(requestBody, normalizedEndpoint) + const processedBody = this._processRequestBody(requestBody, normalizedEndpoint, { + disableStreaming + }) // 发送请求 - const isStreaming = processedBody.stream !== false + const isStreaming = disableStreaming ? false : processedBody.stream !== false // 根据是否流式选择不同的处理方式 if (isStreaming) { @@ -225,7 +237,8 @@ class DroidRelayService { account, keyInfo, requestBody, - normalizedEndpoint + normalizedEndpoint, + skipUsageRecord ) } else { // 非流式响应:使用 axios @@ -253,7 +266,8 @@ class DroidRelayService { keyInfo, requestBody, clientRequest, - normalizedEndpoint + normalizedEndpoint, + skipUsageRecord ) } } catch (error) { @@ -298,7 +312,8 @@ class DroidRelayService { account, apiKeyData, requestBody, - endpointType + endpointType, + skipUsageRecord = false ) { return new Promise((resolve, reject) => { const url = new URL(apiUrl) @@ -449,28 +464,34 @@ class DroidRelayService { clientResponse.end() // 记录 usage 数据 - const normalizedUsage = await this._recordUsageFromStreamData( - currentUsageData, - apiKeyData, - account, - model - ) + if (!skipUsageRecord) { + const normalizedUsage = await this._recordUsageFromStreamData( + currentUsageData, + apiKeyData, + account, + model + ) - const usageSummary = { - inputTokens: normalizedUsage.input_tokens || 0, - outputTokens: normalizedUsage.output_tokens || 0, - cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0, - cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0 + const usageSummary = { + inputTokens: normalizedUsage.input_tokens || 0, + outputTokens: normalizedUsage.output_tokens || 0, + cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0, + cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0 + } + + await this._applyRateLimitTracking( + clientRequest?.rateLimitInfo, + usageSummary, + model, + ' [stream]' + ) + + logger.success(`✅ Droid stream completed - Account: ${account.name}`) + } else { + logger.success( + `✅ Droid stream completed - Account: ${account.name}, usage recording skipped` + ) } - - await this._applyRateLimitTracking( - clientRequest?.rateLimitInfo, - usageSummary, - model, - ' [stream]' - ) - - logger.success(`✅ Droid stream completed - Account: ${account.name}`) resolveOnce({ statusCode: 200, streaming: true }) }) @@ -801,11 +822,15 @@ class DroidRelayService { /** * 处理请求体(注入 system prompt 等) */ - _processRequestBody(requestBody, endpointType) { + _processRequestBody(requestBody, endpointType, options = {}) { + const { disableStreaming = false } = options const processedBody = { ...requestBody } - // 确保 stream 字段存在 - if (processedBody.stream === undefined) { + if (disableStreaming) { + if ('stream' in processedBody) { + delete processedBody.stream + } + } else if (processedBody.stream === undefined) { processedBody.stream = true } @@ -896,7 +921,8 @@ class DroidRelayService { apiKeyData, requestBody, clientRequest, - endpointType + endpointType, + skipUsageRecord = false ) { const { data } = response @@ -906,26 +932,35 @@ class DroidRelayService { const model = requestBody.model || 'unknown' const normalizedUsage = this._normalizeUsageSnapshot(usage) - await this._recordUsage(apiKeyData, account, model, normalizedUsage) - const totalTokens = this._getTotalTokens(normalizedUsage) + if (!skipUsageRecord) { + await this._recordUsage(apiKeyData, account, model, normalizedUsage) - const usageSummary = { - inputTokens: normalizedUsage.input_tokens || 0, - outputTokens: normalizedUsage.output_tokens || 0, - cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0, - cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0 + const totalTokens = this._getTotalTokens(normalizedUsage) + + const usageSummary = { + inputTokens: normalizedUsage.input_tokens || 0, + outputTokens: normalizedUsage.output_tokens || 0, + cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0, + cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0 + } + + await this._applyRateLimitTracking( + clientRequest?.rateLimitInfo, + usageSummary, + model, + endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]' + ) + + logger.success( + `✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}` + ) + } else { + logger.success( + `✅ Droid request completed - Account: ${account.name}, usage recording skipped` + ) } - await this._applyRateLimitTracking( - clientRequest?.rateLimitInfo, - usageSummary, - model, - endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]' - ) - - logger.success(`✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}`) - return { statusCode: 200, headers: { 'Content-Type': 'application/json' },