fix: 优化count_tokens接口不受并发跟客户端限制

This commit is contained in:
shaw
2025-10-10 17:16:10 +08:00
parent 5165d6c536
commit 66fe3cf74a
3 changed files with 161 additions and 49 deletions

View File

@@ -7,6 +7,37 @@ const redis = require('../models/redis')
// const { RateLimiterRedis } = require('rate-limiter-flexible') // 暂时未使用 // const { RateLimiterRedis } = require('rate-limiter-flexible') // 暂时未使用
const ClientValidator = require('../validators/clientValidator') const ClientValidator = require('../validators/clientValidator')
const TOKEN_COUNT_PATHS = new Set([
'/v1/messages/count_tokens',
'/api/v1/messages/count_tokens',
'/claude/v1/messages/count_tokens',
'/droid/claude/v1/messages/count_tokens'
])
function normalizeRequestPath(value) {
if (!value) {
return '/'
}
const lower = value.split('?')[0].toLowerCase()
const collapsed = lower.replace(/\/{2,}/g, '/')
if (collapsed.length > 1 && collapsed.endsWith('/')) {
return collapsed.slice(0, -1)
}
return collapsed || '/'
}
function isTokenCountRequest(req) {
const combined = normalizeRequestPath(`${req.baseUrl || ''}${req.path || ''}`)
if (TOKEN_COUNT_PATHS.has(combined)) {
return true
}
const original = normalizeRequestPath(req.originalUrl || '')
if (TOKEN_COUNT_PATHS.has(original)) {
return true
}
return false
}
// 🔑 API Key验证中间件优化版 // 🔑 API Key验证中间件优化版
const authenticateApiKey = async (req, res, next) => { const authenticateApiKey = async (req, res, next) => {
const startTime = Date.now() const startTime = Date.now()
@@ -49,8 +80,11 @@ const authenticateApiKey = async (req, res, next) => {
}) })
} }
const skipKeyRestrictions = isTokenCountRequest(req)
// 🔒 检查客户端限制(使用新的验证器) // 🔒 检查客户端限制(使用新的验证器)
if ( if (
!skipKeyRestrictions &&
validation.keyData.enableClientRestriction && validation.keyData.enableClientRestriction &&
validation.keyData.allowedClients?.length > 0 validation.keyData.allowedClients?.length > 0
) { ) {
@@ -81,7 +115,7 @@ const authenticateApiKey = async (req, res, next) => {
// 检查并发限制 // 检查并发限制
const concurrencyLimit = validation.keyData.concurrencyLimit || 0 const concurrencyLimit = validation.keyData.concurrencyLimit || 0
if (concurrencyLimit > 0) { if (!skipKeyRestrictions && concurrencyLimit > 0) {
const concurrencyConfig = config.concurrency || {} const concurrencyConfig = config.concurrency || {}
const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30) const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30)
const rawRenewInterval = const rawRenewInterval =

View File

@@ -60,6 +60,49 @@ router.post('/claude/v1/messages', authenticateApiKey, async (req, res) => {
} }
}) })
router.post('/claude/v1/messages/count_tokens', authenticateApiKey, async (req, res) => {
try {
const requestBody = { ...req.body }
if ('stream' in requestBody) {
delete requestBody.stream
}
const sessionHash = sessionHelper.generateSessionHash(requestBody)
if (!hasDroidPermission(req.apiKey)) {
logger.security(
`🚫 API Key ${req.apiKey?.id || 'unknown'} 缺少 Droid 权限,拒绝访问 ${req.originalUrl}`
)
return res.status(403).json({
error: 'permission_denied',
message: '此 API Key 未启用 Droid 权限'
})
}
const result = await droidRelayService.relayRequest(
requestBody,
req.apiKey,
req,
res,
req.headers,
{
endpointType: 'anthropic',
sessionHash,
customPath: '/a/v1/messages/count_tokens',
skipUsageRecord: true,
disableStreaming: true
}
)
res.status(result.statusCode).set(result.headers).send(result.body)
} catch (error) {
logger.error('Droid Claude count_tokens relay error:', error)
res.status(500).json({
error: 'internal_server_error',
message: error.message
})
}
})
// OpenAI 端点 - /v1/responses // OpenAI 端点 - /v1/responses
router.post('/openai/v1/responses', authenticateApiKey, async (req, res) => { router.post('/openai/v1/responses', authenticateApiKey, async (req, res) => {
try { try {

View File

@@ -146,7 +146,13 @@ class DroidRelayService {
clientHeaders, clientHeaders,
options = {} options = {}
) { ) {
const { endpointType = 'anthropic', sessionHash = null } = options const {
endpointType = 'anthropic',
sessionHash = null,
customPath = null,
skipUsageRecord = false,
disableStreaming = false
} = options
const keyInfo = apiKeyData || {} const keyInfo = apiKeyData || {}
const normalizedEndpoint = this._normalizeEndpointType(endpointType) const normalizedEndpoint = this._normalizeEndpointType(endpointType)
@@ -179,8 +185,12 @@ class DroidRelayService {
} }
// 获取 Factory.ai API URL // 获取 Factory.ai API URL
const endpoint = this.endpoints[normalizedEndpoint] let endpointPath = this.endpoints[normalizedEndpoint]
const apiUrl = `${this.factoryApiBaseUrl}${endpoint}` if (typeof customPath === 'string' && customPath.trim()) {
endpointPath = customPath.startsWith('/') ? customPath : `/${customPath}`
}
const apiUrl = `${this.factoryApiBaseUrl}${endpointPath}`
logger.info(`🌐 Forwarding to Factory.ai: ${apiUrl}`) logger.info(`🌐 Forwarding to Factory.ai: ${apiUrl}`)
@@ -207,10 +217,12 @@ class DroidRelayService {
} }
// 处理请求体(注入 system prompt 等) // 处理请求体(注入 system prompt 等)
const processedBody = this._processRequestBody(requestBody, normalizedEndpoint) const processedBody = this._processRequestBody(requestBody, normalizedEndpoint, {
disableStreaming
})
// 发送请求 // 发送请求
const isStreaming = processedBody.stream !== false const isStreaming = disableStreaming ? false : processedBody.stream !== false
// 根据是否流式选择不同的处理方式 // 根据是否流式选择不同的处理方式
if (isStreaming) { if (isStreaming) {
@@ -225,7 +237,8 @@ class DroidRelayService {
account, account,
keyInfo, keyInfo,
requestBody, requestBody,
normalizedEndpoint normalizedEndpoint,
skipUsageRecord
) )
} else { } else {
// 非流式响应:使用 axios // 非流式响应:使用 axios
@@ -253,7 +266,8 @@ class DroidRelayService {
keyInfo, keyInfo,
requestBody, requestBody,
clientRequest, clientRequest,
normalizedEndpoint normalizedEndpoint,
skipUsageRecord
) )
} }
} catch (error) { } catch (error) {
@@ -298,7 +312,8 @@ class DroidRelayService {
account, account,
apiKeyData, apiKeyData,
requestBody, requestBody,
endpointType endpointType,
skipUsageRecord = false
) { ) {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
const url = new URL(apiUrl) const url = new URL(apiUrl)
@@ -449,6 +464,7 @@ class DroidRelayService {
clientResponse.end() clientResponse.end()
// 记录 usage 数据 // 记录 usage 数据
if (!skipUsageRecord) {
const normalizedUsage = await this._recordUsageFromStreamData( const normalizedUsage = await this._recordUsageFromStreamData(
currentUsageData, currentUsageData,
apiKeyData, apiKeyData,
@@ -471,6 +487,11 @@ class DroidRelayService {
) )
logger.success(`✅ Droid stream completed - Account: ${account.name}`) logger.success(`✅ Droid stream completed - Account: ${account.name}`)
} else {
logger.success(
`✅ Droid stream completed - Account: ${account.name}, usage recording skipped`
)
}
resolveOnce({ statusCode: 200, streaming: true }) resolveOnce({ statusCode: 200, streaming: true })
}) })
@@ -801,11 +822,15 @@ class DroidRelayService {
/** /**
* 处理请求体(注入 system prompt 等) * 处理请求体(注入 system prompt 等)
*/ */
_processRequestBody(requestBody, endpointType) { _processRequestBody(requestBody, endpointType, options = {}) {
const { disableStreaming = false } = options
const processedBody = { ...requestBody } const processedBody = { ...requestBody }
// 确保 stream 字段存在 if (disableStreaming) {
if (processedBody.stream === undefined) { if ('stream' in processedBody) {
delete processedBody.stream
}
} else if (processedBody.stream === undefined) {
processedBody.stream = true processedBody.stream = true
} }
@@ -896,7 +921,8 @@ class DroidRelayService {
apiKeyData, apiKeyData,
requestBody, requestBody,
clientRequest, clientRequest,
endpointType endpointType,
skipUsageRecord = false
) { ) {
const { data } = response const { data } = response
@@ -906,6 +932,8 @@ class DroidRelayService {
const model = requestBody.model || 'unknown' const model = requestBody.model || 'unknown'
const normalizedUsage = this._normalizeUsageSnapshot(usage) const normalizedUsage = this._normalizeUsageSnapshot(usage)
if (!skipUsageRecord) {
await this._recordUsage(apiKeyData, account, model, normalizedUsage) await this._recordUsage(apiKeyData, account, model, normalizedUsage)
const totalTokens = this._getTotalTokens(normalizedUsage) const totalTokens = this._getTotalTokens(normalizedUsage)
@@ -924,7 +952,14 @@ class DroidRelayService {
endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]' endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]'
) )
logger.success(`✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}`) logger.success(
`✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}`
)
} else {
logger.success(
`✅ Droid request completed - Account: ${account.name}, usage recording skipped`
)
}
return { return {
statusCode: 200, statusCode: 200,