From eb304c7e7018d18a1c7decedd45fa6b2d270595e Mon Sep 17 00:00:00 2001 From: shaw Date: Wed, 8 Oct 2025 08:36:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20openai=E8=BD=AC=E5=8F=91=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0apikey=E9=80=9F=E7=8E=87=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/routes/api.js | 214 ++++++++++++------------------- src/routes/geminiRoutes.js | 50 ++++++++ src/routes/openaiClaudeRoutes.js | 58 +++++++++ src/routes/openaiRoutes.js | 50 ++++++++ src/utils/rateLimitHelper.js | 71 ++++++++++ 5 files changed, 308 insertions(+), 135 deletions(-) create mode 100644 src/utils/rateLimitHelper.js diff --git a/src/routes/api.js b/src/routes/api.js index d4b572d4..f784cae6 100644 --- a/src/routes/api.js +++ b/src/routes/api.js @@ -6,15 +6,37 @@ const ccrRelayService = require('../services/ccrRelayService') const bedrockAccountService = require('../services/bedrockAccountService') const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler') const apiKeyService = require('../services/apiKeyService') -const pricingService = require('../services/pricingService') const { authenticateApiKey } = require('../middleware/auth') const logger = require('../utils/logger') -const redis = require('../models/redis') const { getEffectiveModel, parseVendorPrefixedModel } = require('../utils/modelHelper') const sessionHelper = require('../utils/sessionHelper') +const { updateRateLimitCounters } = require('../utils/rateLimitHelper') const router = express.Router() +function queueRateLimitUpdate(rateLimitInfo, usageSummary, model, context = '') { + if (!rateLimitInfo) { + return Promise.resolve({ totalTokens: 0, totalCost: 0 }) + } + + const label = context ? ` (${context})` : '' + + return updateRateLimitCounters(rateLimitInfo, usageSummary, model) + .then(({ totalTokens, totalCost }) => { + if (totalTokens > 0) { + logger.api(`📊 Updated rate limit token count${label}: +${totalTokens} tokens`) + } + if (typeof totalCost === 'number' && totalCost > 0) { + logger.api(`💰 Updated rate limit cost count${label}: +$${totalCost.toFixed(6)}`) + } + return { totalTokens, totalCost } + }) + .catch((error) => { + logger.error(`❌ Failed to update rate limit counters${label}:`, error) + return { totalTokens: 0, totalCost: 0 } + }) +} + // 🔧 共享的消息处理函数 async function handleMessagesRequest(req, res) { try { @@ -191,35 +213,17 @@ async function handleMessagesRequest(req, res) { logger.error('❌ Failed to record stream usage:', error) }) - // 更新时间窗口内的token计数和费用 - if (req.rateLimitInfo) { - const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens - - // 更新Token计数(向后兼容) - redis - .getClient() - .incrby(req.rateLimitInfo.tokenCountKey, totalTokens) - .catch((error) => { - logger.error('❌ Failed to update rate limit token count:', error) - }) - logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`) - - // 计算并更新费用计数(新功能) - if (req.rateLimitInfo.costCountKey) { - const costInfo = pricingService.calculateCost(usageData, model) - if (costInfo.totalCost > 0) { - redis - .getClient() - .incrbyfloat(req.rateLimitInfo.costCountKey, costInfo.totalCost) - .catch((error) => { - logger.error('❌ Failed to update rate limit cost count:', error) - }) - logger.api( - `💰 Updated rate limit cost count: +$${costInfo.totalCost.toFixed(6)}` - ) - } - } - } + queueRateLimitUpdate( + req.rateLimitInfo, + { + inputTokens, + outputTokens, + cacheCreateTokens, + cacheReadTokens + }, + model, + 'claude-stream' + ) usageDataCaptured = true logger.api( @@ -300,35 +304,17 @@ async function handleMessagesRequest(req, res) { logger.error('❌ Failed to record stream usage:', error) }) - // 更新时间窗口内的token计数和费用 - if (req.rateLimitInfo) { - const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens - - // 更新Token计数(向后兼容) - redis - .getClient() - .incrby(req.rateLimitInfo.tokenCountKey, totalTokens) - .catch((error) => { - logger.error('❌ Failed to update rate limit token count:', error) - }) - logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`) - - // 计算并更新费用计数(新功能) - if (req.rateLimitInfo.costCountKey) { - const costInfo = pricingService.calculateCost(usageData, model) - if (costInfo.totalCost > 0) { - redis - .getClient() - .incrbyfloat(req.rateLimitInfo.costCountKey, costInfo.totalCost) - .catch((error) => { - logger.error('❌ Failed to update rate limit cost count:', error) - }) - logger.api( - `💰 Updated rate limit cost count: +$${costInfo.totalCost.toFixed(6)}` - ) - } - } - } + queueRateLimitUpdate( + req.rateLimitInfo, + { + inputTokens, + outputTokens, + cacheCreateTokens, + cacheReadTokens + }, + model, + 'claude-console-stream' + ) usageDataCaptured = true logger.api( @@ -368,33 +354,17 @@ async function handleMessagesRequest(req, res) { logger.error('❌ Failed to record Bedrock stream usage:', error) }) - // 更新时间窗口内的token计数和费用 - if (req.rateLimitInfo) { - const totalTokens = inputTokens + outputTokens - - // 更新Token计数(向后兼容) - redis - .getClient() - .incrby(req.rateLimitInfo.tokenCountKey, totalTokens) - .catch((error) => { - logger.error('❌ Failed to update rate limit token count:', error) - }) - logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`) - - // 计算并更新费用计数(新功能) - if (req.rateLimitInfo.costCountKey) { - const costInfo = pricingService.calculateCost(result.usage, result.model) - if (costInfo.totalCost > 0) { - redis - .getClient() - .incrbyfloat(req.rateLimitInfo.costCountKey, costInfo.totalCost) - .catch((error) => { - logger.error('❌ Failed to update rate limit cost count:', error) - }) - logger.api(`💰 Updated rate limit cost count: +$${costInfo.totalCost.toFixed(6)}`) - } - } - } + queueRateLimitUpdate( + req.rateLimitInfo, + { + inputTokens, + outputTokens, + cacheCreateTokens: 0, + cacheReadTokens: 0 + }, + result.model, + 'bedrock-stream' + ) usageDataCaptured = true logger.api( @@ -469,35 +439,17 @@ async function handleMessagesRequest(req, res) { logger.error('❌ Failed to record CCR stream usage:', error) }) - // 更新时间窗口内的token计数和费用 - if (req.rateLimitInfo) { - const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens - - // 更新Token计数(向后兼容) - redis - .getClient() - .incrby(req.rateLimitInfo.tokenCountKey, totalTokens) - .catch((error) => { - logger.error('❌ Failed to update rate limit token count:', error) - }) - logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`) - - // 计算并更新费用计数(新功能) - if (req.rateLimitInfo.costCountKey) { - const costInfo = pricingService.calculateCost(usageData, model) - if (costInfo.totalCost > 0) { - redis - .getClient() - .incrbyfloat(req.rateLimitInfo.costCountKey, costInfo.totalCost) - .catch((error) => { - logger.error('❌ Failed to update rate limit cost count:', error) - }) - logger.api( - `💰 Updated rate limit cost count: +$${costInfo.totalCost.toFixed(6)}` - ) - } - } - } + queueRateLimitUpdate( + req.rateLimitInfo, + { + inputTokens, + outputTokens, + cacheCreateTokens, + cacheReadTokens + }, + model, + 'ccr-stream' + ) usageDataCaptured = true logger.api( @@ -685,25 +637,17 @@ async function handleMessagesRequest(req, res) { responseAccountId ) - // 更新时间窗口内的token计数和费用 - if (req.rateLimitInfo) { - const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens - - // 更新Token计数(向后兼容) - await redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens) - logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`) - - // 计算并更新费用计数(新功能) - if (req.rateLimitInfo.costCountKey) { - const costInfo = pricingService.calculateCost(jsonData.usage, model) - if (costInfo.totalCost > 0) { - await redis - .getClient() - .incrbyfloat(req.rateLimitInfo.costCountKey, costInfo.totalCost) - logger.api(`💰 Updated rate limit cost count: +$${costInfo.totalCost.toFixed(6)}`) - } - } - } + await queueRateLimitUpdate( + req.rateLimitInfo, + { + inputTokens, + outputTokens, + cacheCreateTokens, + cacheReadTokens + }, + model, + 'claude-non-stream' + ) usageRecorded = true logger.api( diff --git a/src/routes/geminiRoutes.js b/src/routes/geminiRoutes.js index df447fb7..532979cf 100644 --- a/src/routes/geminiRoutes.js +++ b/src/routes/geminiRoutes.js @@ -8,6 +8,7 @@ const crypto = require('crypto') const sessionHelper = require('../utils/sessionHelper') const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler') const apiKeyService = require('../services/apiKeyService') +const { updateRateLimitCounters } = require('../utils/rateLimitHelper') // const { OAuth2Client } = require('google-auth-library'); // OAuth2Client is not used in this file // 生成会话哈希 @@ -49,6 +50,31 @@ function ensureGeminiPermission(req, res) { return false } +async function applyRateLimitTracking(req, usageSummary, model, context = '') { + if (!req.rateLimitInfo) { + return + } + + const label = context ? ` (${context})` : '' + + try { + const { totalTokens, totalCost } = await updateRateLimitCounters( + req.rateLimitInfo, + usageSummary, + model + ) + + if (totalTokens > 0) { + logger.api(`📊 Updated rate limit token count${label}: +${totalTokens} tokens`) + } + if (typeof totalCost === 'number' && totalCost > 0) { + logger.api(`💰 Updated rate limit cost count${label}: +$${totalCost.toFixed(6)}`) + } + } catch (error) { + logger.error(`❌ Failed to update rate limit counters${label}:`, error) + } +} + // Gemini 消息处理端点 router.post('/messages', authenticateApiKey, async (req, res) => { const startTime = Date.now() @@ -679,6 +705,18 @@ async function handleGenerateContent(req, res) { logger.info( `📊 Recorded Gemini usage - Input: ${usage.promptTokenCount}, Output: ${usage.candidatesTokenCount}, Total: ${usage.totalTokenCount}` ) + + await applyRateLimitTracking( + req, + { + inputTokens: usage.promptTokenCount || 0, + outputTokens: usage.candidatesTokenCount || 0, + cacheCreateTokens: 0, + cacheReadTokens: 0 + }, + model, + 'gemini-non-stream' + ) } catch (error) { logger.error('Failed to record Gemini usage:', error) } @@ -935,6 +973,18 @@ async function handleStreamGenerateContent(req, res) { logger.info( `📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}` ) + + await applyRateLimitTracking( + req, + { + inputTokens: totalUsage.promptTokenCount || 0, + outputTokens: totalUsage.candidatesTokenCount || 0, + cacheCreateTokens: 0, + cacheReadTokens: 0 + }, + model, + 'gemini-stream' + ) } catch (error) { logger.error('Failed to record Gemini usage:', error) } diff --git a/src/routes/openaiClaudeRoutes.js b/src/routes/openaiClaudeRoutes.js index e1514d5b..f5db5665 100644 --- a/src/routes/openaiClaudeRoutes.js +++ b/src/routes/openaiClaudeRoutes.js @@ -15,6 +15,7 @@ const apiKeyService = require('../services/apiKeyService') const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler') const claudeCodeHeadersService = require('../services/claudeCodeHeadersService') const sessionHelper = require('../utils/sessionHelper') +const { updateRateLimitCounters } = require('../utils/rateLimitHelper') // 加载模型定价数据 let modelPricingData = {} @@ -33,6 +34,27 @@ function checkPermissions(apiKeyData, requiredPermission = 'claude') { return permissions === 'all' || permissions === requiredPermission } +function queueRateLimitUpdate(rateLimitInfo, usageSummary, model, context = '') { + if (!rateLimitInfo) { + return + } + + const label = context ? ` (${context})` : '' + + updateRateLimitCounters(rateLimitInfo, usageSummary, model) + .then(({ totalTokens, totalCost }) => { + if (totalTokens > 0) { + logger.api(`📊 Updated rate limit token count${label}: +${totalTokens} tokens`) + } + if (typeof totalCost === 'number' && totalCost > 0) { + logger.api(`💰 Updated rate limit cost count${label}: +$${totalCost.toFixed(6)}`) + } + }) + .catch((error) => { + logger.error(`❌ Failed to update rate limit counters${label}:`, error) + }) +} + // 📋 OpenAI 兼容的模型列表端点 router.get('/v1/models', authenticateApiKey, async (req, res) => { try { @@ -263,6 +285,12 @@ async function handleChatCompletion(req, res, apiKeyData) { // 记录使用统计 if (usage && usage.input_tokens !== undefined && usage.output_tokens !== undefined) { const model = usage.model || claudeRequest.model + const cacheCreateTokens = + (usage.cache_creation && typeof usage.cache_creation === 'object' + ? (usage.cache_creation.ephemeral_5m_input_tokens || 0) + + (usage.cache_creation.ephemeral_1h_input_tokens || 0) + : usage.cache_creation_input_tokens || 0) || 0 + const cacheReadTokens = usage.cache_read_input_tokens || 0 // 使用新的 recordUsageWithDetails 方法来支持详细的缓存数据 apiKeyService @@ -275,6 +303,18 @@ async function handleChatCompletion(req, res, apiKeyData) { .catch((error) => { logger.error('❌ Failed to record usage:', error) }) + + queueRateLimitUpdate( + req.rateLimitInfo, + { + inputTokens: usage.input_tokens || 0, + outputTokens: usage.output_tokens || 0, + cacheCreateTokens, + cacheReadTokens + }, + model, + 'openai-claude-stream' + ) } }, // 流转换器 @@ -334,6 +374,12 @@ async function handleChatCompletion(req, res, apiKeyData) { // 记录使用统计 if (claudeData.usage) { const { usage } = claudeData + const cacheCreateTokens = + (usage.cache_creation && typeof usage.cache_creation === 'object' + ? (usage.cache_creation.ephemeral_5m_input_tokens || 0) + + (usage.cache_creation.ephemeral_1h_input_tokens || 0) + : usage.cache_creation_input_tokens || 0) || 0 + const cacheReadTokens = usage.cache_read_input_tokens || 0 // 使用新的 recordUsageWithDetails 方法来支持详细的缓存数据 apiKeyService .recordUsageWithDetails( @@ -345,6 +391,18 @@ async function handleChatCompletion(req, res, apiKeyData) { .catch((error) => { logger.error('❌ Failed to record usage:', error) }) + + queueRateLimitUpdate( + req.rateLimitInfo, + { + inputTokens: usage.input_tokens || 0, + outputTokens: usage.output_tokens || 0, + cacheCreateTokens, + cacheReadTokens + }, + claudeRequest.model, + 'openai-claude-non-stream' + ) } // 返回 OpenAI 格式响应 diff --git a/src/routes/openaiRoutes.js b/src/routes/openaiRoutes.js index 604d8c35..13776c8d 100644 --- a/src/routes/openaiRoutes.js +++ b/src/routes/openaiRoutes.js @@ -11,6 +11,7 @@ const openaiResponsesRelayService = require('../services/openaiResponsesRelaySer const apiKeyService = require('../services/apiKeyService') const crypto = require('crypto') const ProxyHelper = require('../utils/proxyHelper') +const { updateRateLimitCounters } = require('../utils/rateLimitHelper') // 创建代理 Agent(使用统一的代理工具) function createProxyAgent(proxy) { @@ -67,6 +68,31 @@ function extractCodexUsageHeaders(headers) { return hasData ? snapshot : null } +async function applyRateLimitTracking(req, usageSummary, model, context = '') { + if (!req.rateLimitInfo) { + return + } + + const label = context ? ` (${context})` : '' + + try { + const { totalTokens, totalCost } = await updateRateLimitCounters( + req.rateLimitInfo, + usageSummary, + model + ) + + if (totalTokens > 0) { + logger.api(`📊 Updated rate limit token count${label}: +${totalTokens} tokens`) + } + if (typeof totalCost === 'number' && totalCost > 0) { + logger.api(`💰 Updated rate limit cost count${label}: +$${totalCost.toFixed(6)}`) + } + } catch (error) { + logger.error(`❌ Failed to update rate limit counters${label}:`, error) + } +} + // 使用统一调度器选择 OpenAI 账户 async function getOpenAIAuthToken(apiKeyData, sessionId = null, requestedModel = null) { try { @@ -579,6 +605,18 @@ const handleResponses = async (req, res) => { logger.info( `📊 Recorded OpenAI non-stream usage - Input: ${totalInputTokens}(actual:${actualInputTokens}+cached:${cacheReadTokens}), Output: ${outputTokens}, Total: ${usageData.total_tokens || totalInputTokens + outputTokens}, Model: ${actualModel}` ) + + await applyRateLimitTracking( + req, + { + inputTokens: actualInputTokens, + outputTokens, + cacheCreateTokens: 0, + cacheReadTokens + }, + actualModel, + 'openai-non-stream' + ) } // 返回响应 @@ -700,6 +738,18 @@ const handleResponses = async (req, res) => { `📊 Recorded OpenAI usage - Input: ${totalInputTokens}(actual:${actualInputTokens}+cached:${cacheReadTokens}), Output: ${outputTokens}, Total: ${usageData.total_tokens || totalInputTokens + outputTokens}, Model: ${modelToRecord} (actual: ${actualModel}, requested: ${requestedModel})` ) usageReported = true + + await applyRateLimitTracking( + req, + { + inputTokens: actualInputTokens, + outputTokens, + cacheCreateTokens: 0, + cacheReadTokens + }, + modelToRecord, + 'openai-stream' + ) } catch (error) { logger.error('Failed to record OpenAI usage:', error) } diff --git a/src/utils/rateLimitHelper.js b/src/utils/rateLimitHelper.js new file mode 100644 index 00000000..38c38568 --- /dev/null +++ b/src/utils/rateLimitHelper.js @@ -0,0 +1,71 @@ +const redis = require('../models/redis') +const pricingService = require('../services/pricingService') +const CostCalculator = require('./costCalculator') + +function toNumber(value) { + const num = Number(value) + return Number.isFinite(num) ? num : 0 +} + +async function updateRateLimitCounters(rateLimitInfo, usageSummary, model) { + if (!rateLimitInfo) { + return { totalTokens: 0, totalCost: 0 } + } + + const client = redis.getClient() + if (!client) { + throw new Error('Redis 未连接,无法更新限流计数') + } + + const inputTokens = toNumber(usageSummary.inputTokens) + const outputTokens = toNumber(usageSummary.outputTokens) + const cacheCreateTokens = toNumber(usageSummary.cacheCreateTokens) + const cacheReadTokens = toNumber(usageSummary.cacheReadTokens) + + const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens + + if (totalTokens > 0 && rateLimitInfo.tokenCountKey) { + await client.incrby(rateLimitInfo.tokenCountKey, Math.round(totalTokens)) + } + + let totalCost = 0 + const usagePayload = { + input_tokens: inputTokens, + output_tokens: outputTokens, + cache_creation_input_tokens: cacheCreateTokens, + cache_read_input_tokens: cacheReadTokens + } + + try { + const costInfo = pricingService.calculateCost(usagePayload, model) + const { totalCost: calculatedCost } = costInfo || {} + if (typeof calculatedCost === 'number') { + totalCost = calculatedCost + } + } catch (error) { + // 忽略此处错误,后续使用备用计算 + totalCost = 0 + } + + if (totalCost === 0) { + try { + const fallback = CostCalculator.calculateCost(usagePayload, model) + const { costs } = fallback || {} + if (costs && typeof costs.total === 'number') { + totalCost = costs.total + } + } catch (error) { + totalCost = 0 + } + } + + if (totalCost > 0 && rateLimitInfo.costCountKey) { + await client.incrbyfloat(rateLimitInfo.costCountKey, totalCost) + } + + return { totalTokens, totalCost } +} + +module.exports = { + updateRateLimitCounters +}