fix: 优化count_tokens接口不受并发跟客户端限制

This commit is contained in:
shaw
2025-10-10 17:16:10 +08:00
parent 5165d6c536
commit 66fe3cf74a
3 changed files with 161 additions and 49 deletions

View File

@@ -7,6 +7,37 @@ const redis = require('../models/redis')
// const { RateLimiterRedis } = require('rate-limiter-flexible') // 暂时未使用
const ClientValidator = require('../validators/clientValidator')
// Request paths that identify a count_tokens call. isTokenCountRequest() looks
// up normalized paths in this set, and authenticateApiKey uses the result to
// skip the client restriction and concurrency limit checks for such requests.
// Entries must be lower-case with no trailing slash, because candidate paths
// are passed through normalizeRequestPath() before the lookup.
const TOKEN_COUNT_PATHS = new Set([
'/v1/messages/count_tokens',
'/api/v1/messages/count_tokens',
'/claude/v1/messages/count_tokens',
'/droid/claude/v1/messages/count_tokens'
])
/**
 * Reduce a request URL or path to a canonical form suitable for set lookups.
 *
 * The query string (everything from the first `?`) is dropped, the remainder
 * is lower-cased, runs of consecutive slashes become a single slash, and one
 * trailing slash is removed unless the path is just `/`.
 *
 * @param {string} value - Raw URL or path; any falsy value is treated as the root.
 * @returns {string} The canonical path; `/` when nothing is left after normalization.
 */
function normalizeRequestPath(value) {
  if (!value) {
    return '/'
  }

  const [pathOnly] = value.split('?')
  const normalized = pathOnly.toLowerCase().replace(/\/{2,}/g, '/')

  // Keep the lone root slash; strip a trailing slash from anything longer.
  const hasTrailingSlash = normalized.length > 1 && normalized.endsWith('/')
  const trimmed = hasTrailingSlash ? normalized.slice(0, -1) : normalized

  return trimmed || '/'
}
/**
 * Tell whether an incoming request targets one of the count_tokens endpoints.
 *
 * Two candidate paths are checked: the router-relative path (`baseUrl` + `path`)
 * and the original URL. Each is normalized before being looked up in
 * TOKEN_COUNT_PATHS; a match on either one is sufficient.
 *
 * @param {object} req - Request object exposing `baseUrl`, `path` and `originalUrl`.
 * @returns {boolean} True when either candidate path is a known count_tokens path.
 */
function isTokenCountRequest(req) {
  const candidates = [`${req.baseUrl || ''}${req.path || ''}`, req.originalUrl || '']

  return candidates.some((candidate) => TOKEN_COUNT_PATHS.has(normalizeRequestPath(candidate)))
}
// 🔑 API Key验证中间件优化版
const authenticateApiKey = async (req, res, next) => {
const startTime = Date.now()
@@ -49,8 +80,11 @@ const authenticateApiKey = async (req, res, next) => {
})
}
const skipKeyRestrictions = isTokenCountRequest(req)
// 🔒 检查客户端限制(使用新的验证器)
if (
!skipKeyRestrictions &&
validation.keyData.enableClientRestriction &&
validation.keyData.allowedClients?.length > 0
) {
@@ -81,7 +115,7 @@ const authenticateApiKey = async (req, res, next) => {
// 检查并发限制
const concurrencyLimit = validation.keyData.concurrencyLimit || 0
if (concurrencyLimit > 0) {
if (!skipKeyRestrictions && concurrencyLimit > 0) {
const concurrencyConfig = config.concurrency || {}
const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30)
const rawRenewInterval =

View File

@@ -60,6 +60,49 @@ router.post('/claude/v1/messages', authenticateApiKey, async (req, res) => {
}
})
// Droid Claude token counting endpoint - /claude/v1/messages/count_tokens
// Relays the request body to the upstream count_tokens path as a non-streaming
// call and asks the relay service to skip usage recording for it.
router.post('/claude/v1/messages/count_tokens', authenticateApiKey, async (req, res) => {
  try {
    // Reject keys without Droid permission before doing any other work, so an
    // unauthorized caller never triggers hashing and always receives the 403.
    if (!hasDroidPermission(req.apiKey)) {
      logger.security(
        `🚫 API Key ${req.apiKey?.id || 'unknown'} 缺少 Droid 权限,拒绝访问 ${req.originalUrl}`
      )
      return res.status(403).json({
        error: 'permission_denied',
        message: '此 API Key 未启用 Droid 权限'
      })
    }

    // Work on a shallow copy so req.body is not mutated; the stream flag is
    // dropped because this endpoint is always relayed as non-streaming.
    const requestBody = { ...req.body }
    delete requestBody.stream

    const sessionHash = sessionHelper.generateSessionHash(requestBody)

    const result = await droidRelayService.relayRequest(
      requestBody,
      req.apiKey,
      req,
      res,
      req.headers,
      {
        endpointType: 'anthropic',
        sessionHash,
        customPath: '/a/v1/messages/count_tokens',
        skipUsageRecord: true,
        disableStreaming: true
      }
    )

    res.status(result.statusCode).set(result.headers).send(result.body)
  } catch (error) {
    logger.error('Droid Claude count_tokens relay error:', error)

    // If part of a response has already gone out, a JSON error body can no
    // longer be written; setting the status again would throw inside this
    // async handler. Just make sure the response is terminated.
    if (res.headersSent) {
      if (!res.writableEnded) {
        res.end()
      }
      return
    }

    res.status(500).json({
      error: 'internal_server_error',
      message: error.message
    })
  }
})
// OpenAI 端点 - /v1/responses
router.post('/openai/v1/responses', authenticateApiKey, async (req, res) => {
try {

View File

@@ -146,7 +146,13 @@ class DroidRelayService {
clientHeaders,
options = {}
) {
const { endpointType = 'anthropic', sessionHash = null } = options
const {
endpointType = 'anthropic',
sessionHash = null,
customPath = null,
skipUsageRecord = false,
disableStreaming = false
} = options
const keyInfo = apiKeyData || {}
const normalizedEndpoint = this._normalizeEndpointType(endpointType)
@@ -179,8 +185,12 @@ class DroidRelayService {
}
// 获取 Factory.ai API URL
const endpoint = this.endpoints[normalizedEndpoint]
const apiUrl = `${this.factoryApiBaseUrl}${endpoint}`
let endpointPath = this.endpoints[normalizedEndpoint]
if (typeof customPath === 'string' && customPath.trim()) {
endpointPath = customPath.startsWith('/') ? customPath : `/${customPath}`
}
const apiUrl = `${this.factoryApiBaseUrl}${endpointPath}`
logger.info(`🌐 Forwarding to Factory.ai: ${apiUrl}`)
@@ -207,10 +217,12 @@ class DroidRelayService {
}
// 处理请求体(注入 system prompt 等)
const processedBody = this._processRequestBody(requestBody, normalizedEndpoint)
const processedBody = this._processRequestBody(requestBody, normalizedEndpoint, {
disableStreaming
})
// 发送请求
const isStreaming = processedBody.stream !== false
const isStreaming = disableStreaming ? false : processedBody.stream !== false
// 根据是否流式选择不同的处理方式
if (isStreaming) {
@@ -225,7 +237,8 @@ class DroidRelayService {
account,
keyInfo,
requestBody,
normalizedEndpoint
normalizedEndpoint,
skipUsageRecord
)
} else {
// 非流式响应:使用 axios
@@ -253,7 +266,8 @@ class DroidRelayService {
keyInfo,
requestBody,
clientRequest,
normalizedEndpoint
normalizedEndpoint,
skipUsageRecord
)
}
} catch (error) {
@@ -298,7 +312,8 @@ class DroidRelayService {
account,
apiKeyData,
requestBody,
endpointType
endpointType,
skipUsageRecord = false
) {
return new Promise((resolve, reject) => {
const url = new URL(apiUrl)
@@ -449,28 +464,34 @@ class DroidRelayService {
clientResponse.end()
// 记录 usage 数据
const normalizedUsage = await this._recordUsageFromStreamData(
currentUsageData,
apiKeyData,
account,
model
)
if (!skipUsageRecord) {
const normalizedUsage = await this._recordUsageFromStreamData(
currentUsageData,
apiKeyData,
account,
model
)
const usageSummary = {
inputTokens: normalizedUsage.input_tokens || 0,
outputTokens: normalizedUsage.output_tokens || 0,
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
const usageSummary = {
inputTokens: normalizedUsage.input_tokens || 0,
outputTokens: normalizedUsage.output_tokens || 0,
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
}
await this._applyRateLimitTracking(
clientRequest?.rateLimitInfo,
usageSummary,
model,
' [stream]'
)
logger.success(`✅ Droid stream completed - Account: ${account.name}`)
} else {
logger.success(
`✅ Droid stream completed - Account: ${account.name}, usage recording skipped`
)
}
await this._applyRateLimitTracking(
clientRequest?.rateLimitInfo,
usageSummary,
model,
' [stream]'
)
logger.success(`✅ Droid stream completed - Account: ${account.name}`)
resolveOnce({ statusCode: 200, streaming: true })
})
@@ -801,11 +822,15 @@ class DroidRelayService {
/**
* 处理请求体(注入 system prompt 等)
*/
_processRequestBody(requestBody, endpointType) {
_processRequestBody(requestBody, endpointType, options = {}) {
const { disableStreaming = false } = options
const processedBody = { ...requestBody }
// 确保 stream 字段存在
if (processedBody.stream === undefined) {
if (disableStreaming) {
if ('stream' in processedBody) {
delete processedBody.stream
}
} else if (processedBody.stream === undefined) {
processedBody.stream = true
}
@@ -896,7 +921,8 @@ class DroidRelayService {
apiKeyData,
requestBody,
clientRequest,
endpointType
endpointType,
skipUsageRecord = false
) {
const { data } = response
@@ -906,26 +932,35 @@ class DroidRelayService {
const model = requestBody.model || 'unknown'
const normalizedUsage = this._normalizeUsageSnapshot(usage)
await this._recordUsage(apiKeyData, account, model, normalizedUsage)
const totalTokens = this._getTotalTokens(normalizedUsage)
if (!skipUsageRecord) {
await this._recordUsage(apiKeyData, account, model, normalizedUsage)
const usageSummary = {
inputTokens: normalizedUsage.input_tokens || 0,
outputTokens: normalizedUsage.output_tokens || 0,
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
const totalTokens = this._getTotalTokens(normalizedUsage)
const usageSummary = {
inputTokens: normalizedUsage.input_tokens || 0,
outputTokens: normalizedUsage.output_tokens || 0,
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
}
await this._applyRateLimitTracking(
clientRequest?.rateLimitInfo,
usageSummary,
model,
endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]'
)
logger.success(
`✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}`
)
} else {
logger.success(
`✅ Droid request completed - Account: ${account.name}, usage recording skipped`
)
}
await this._applyRateLimitTracking(
clientRequest?.rateLimitInfo,
usageSummary,
model,
endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]'
)
logger.success(`✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}`)
return {
statusCode: 200,
headers: { 'Content-Type': 'application/json' },