mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-23 09:06:18 +00:00
fix: 优化count_tokens接口不受并发跟客户端限制
This commit is contained in:
@@ -7,6 +7,37 @@ const redis = require('../models/redis')
|
|||||||
// const { RateLimiterRedis } = require('rate-limiter-flexible') // 暂时未使用
|
// const { RateLimiterRedis } = require('rate-limiter-flexible') // 暂时未使用
|
||||||
const ClientValidator = require('../validators/clientValidator')
|
const ClientValidator = require('../validators/clientValidator')
|
||||||
|
|
||||||
|
const TOKEN_COUNT_PATHS = new Set([
|
||||||
|
'/v1/messages/count_tokens',
|
||||||
|
'/api/v1/messages/count_tokens',
|
||||||
|
'/claude/v1/messages/count_tokens',
|
||||||
|
'/droid/claude/v1/messages/count_tokens'
|
||||||
|
])
|
||||||
|
|
||||||
|
function normalizeRequestPath(value) {
|
||||||
|
if (!value) {
|
||||||
|
return '/'
|
||||||
|
}
|
||||||
|
const lower = value.split('?')[0].toLowerCase()
|
||||||
|
const collapsed = lower.replace(/\/{2,}/g, '/')
|
||||||
|
if (collapsed.length > 1 && collapsed.endsWith('/')) {
|
||||||
|
return collapsed.slice(0, -1)
|
||||||
|
}
|
||||||
|
return collapsed || '/'
|
||||||
|
}
|
||||||
|
|
||||||
|
function isTokenCountRequest(req) {
|
||||||
|
const combined = normalizeRequestPath(`${req.baseUrl || ''}${req.path || ''}`)
|
||||||
|
if (TOKEN_COUNT_PATHS.has(combined)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
const original = normalizeRequestPath(req.originalUrl || '')
|
||||||
|
if (TOKEN_COUNT_PATHS.has(original)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// 🔑 API Key验证中间件(优化版)
|
// 🔑 API Key验证中间件(优化版)
|
||||||
const authenticateApiKey = async (req, res, next) => {
|
const authenticateApiKey = async (req, res, next) => {
|
||||||
const startTime = Date.now()
|
const startTime = Date.now()
|
||||||
@@ -49,8 +80,11 @@ const authenticateApiKey = async (req, res, next) => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const skipKeyRestrictions = isTokenCountRequest(req)
|
||||||
|
|
||||||
// 🔒 检查客户端限制(使用新的验证器)
|
// 🔒 检查客户端限制(使用新的验证器)
|
||||||
if (
|
if (
|
||||||
|
!skipKeyRestrictions &&
|
||||||
validation.keyData.enableClientRestriction &&
|
validation.keyData.enableClientRestriction &&
|
||||||
validation.keyData.allowedClients?.length > 0
|
validation.keyData.allowedClients?.length > 0
|
||||||
) {
|
) {
|
||||||
@@ -81,7 +115,7 @@ const authenticateApiKey = async (req, res, next) => {
|
|||||||
|
|
||||||
// 检查并发限制
|
// 检查并发限制
|
||||||
const concurrencyLimit = validation.keyData.concurrencyLimit || 0
|
const concurrencyLimit = validation.keyData.concurrencyLimit || 0
|
||||||
if (concurrencyLimit > 0) {
|
if (!skipKeyRestrictions && concurrencyLimit > 0) {
|
||||||
const concurrencyConfig = config.concurrency || {}
|
const concurrencyConfig = config.concurrency || {}
|
||||||
const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30)
|
const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30)
|
||||||
const rawRenewInterval =
|
const rawRenewInterval =
|
||||||
|
|||||||
@@ -60,6 +60,49 @@ router.post('/claude/v1/messages', authenticateApiKey, async (req, res) => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
router.post('/claude/v1/messages/count_tokens', authenticateApiKey, async (req, res) => {
|
||||||
|
try {
|
||||||
|
const requestBody = { ...req.body }
|
||||||
|
if ('stream' in requestBody) {
|
||||||
|
delete requestBody.stream
|
||||||
|
}
|
||||||
|
const sessionHash = sessionHelper.generateSessionHash(requestBody)
|
||||||
|
|
||||||
|
if (!hasDroidPermission(req.apiKey)) {
|
||||||
|
logger.security(
|
||||||
|
`🚫 API Key ${req.apiKey?.id || 'unknown'} 缺少 Droid 权限,拒绝访问 ${req.originalUrl}`
|
||||||
|
)
|
||||||
|
return res.status(403).json({
|
||||||
|
error: 'permission_denied',
|
||||||
|
message: '此 API Key 未启用 Droid 权限'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await droidRelayService.relayRequest(
|
||||||
|
requestBody,
|
||||||
|
req.apiKey,
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
req.headers,
|
||||||
|
{
|
||||||
|
endpointType: 'anthropic',
|
||||||
|
sessionHash,
|
||||||
|
customPath: '/a/v1/messages/count_tokens',
|
||||||
|
skipUsageRecord: true,
|
||||||
|
disableStreaming: true
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
res.status(result.statusCode).set(result.headers).send(result.body)
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Droid Claude count_tokens relay error:', error)
|
||||||
|
res.status(500).json({
|
||||||
|
error: 'internal_server_error',
|
||||||
|
message: error.message
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// OpenAI 端点 - /v1/responses
|
// OpenAI 端点 - /v1/responses
|
||||||
router.post('/openai/v1/responses', authenticateApiKey, async (req, res) => {
|
router.post('/openai/v1/responses', authenticateApiKey, async (req, res) => {
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -146,7 +146,13 @@ class DroidRelayService {
|
|||||||
clientHeaders,
|
clientHeaders,
|
||||||
options = {}
|
options = {}
|
||||||
) {
|
) {
|
||||||
const { endpointType = 'anthropic', sessionHash = null } = options
|
const {
|
||||||
|
endpointType = 'anthropic',
|
||||||
|
sessionHash = null,
|
||||||
|
customPath = null,
|
||||||
|
skipUsageRecord = false,
|
||||||
|
disableStreaming = false
|
||||||
|
} = options
|
||||||
const keyInfo = apiKeyData || {}
|
const keyInfo = apiKeyData || {}
|
||||||
const normalizedEndpoint = this._normalizeEndpointType(endpointType)
|
const normalizedEndpoint = this._normalizeEndpointType(endpointType)
|
||||||
|
|
||||||
@@ -179,8 +185,12 @@ class DroidRelayService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取 Factory.ai API URL
|
// 获取 Factory.ai API URL
|
||||||
const endpoint = this.endpoints[normalizedEndpoint]
|
let endpointPath = this.endpoints[normalizedEndpoint]
|
||||||
const apiUrl = `${this.factoryApiBaseUrl}${endpoint}`
|
if (typeof customPath === 'string' && customPath.trim()) {
|
||||||
|
endpointPath = customPath.startsWith('/') ? customPath : `/${customPath}`
|
||||||
|
}
|
||||||
|
|
||||||
|
const apiUrl = `${this.factoryApiBaseUrl}${endpointPath}`
|
||||||
|
|
||||||
logger.info(`🌐 Forwarding to Factory.ai: ${apiUrl}`)
|
logger.info(`🌐 Forwarding to Factory.ai: ${apiUrl}`)
|
||||||
|
|
||||||
@@ -207,10 +217,12 @@ class DroidRelayService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 处理请求体(注入 system prompt 等)
|
// 处理请求体(注入 system prompt 等)
|
||||||
const processedBody = this._processRequestBody(requestBody, normalizedEndpoint)
|
const processedBody = this._processRequestBody(requestBody, normalizedEndpoint, {
|
||||||
|
disableStreaming
|
||||||
|
})
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
const isStreaming = processedBody.stream !== false
|
const isStreaming = disableStreaming ? false : processedBody.stream !== false
|
||||||
|
|
||||||
// 根据是否流式选择不同的处理方式
|
// 根据是否流式选择不同的处理方式
|
||||||
if (isStreaming) {
|
if (isStreaming) {
|
||||||
@@ -225,7 +237,8 @@ class DroidRelayService {
|
|||||||
account,
|
account,
|
||||||
keyInfo,
|
keyInfo,
|
||||||
requestBody,
|
requestBody,
|
||||||
normalizedEndpoint
|
normalizedEndpoint,
|
||||||
|
skipUsageRecord
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// 非流式响应:使用 axios
|
// 非流式响应:使用 axios
|
||||||
@@ -253,7 +266,8 @@ class DroidRelayService {
|
|||||||
keyInfo,
|
keyInfo,
|
||||||
requestBody,
|
requestBody,
|
||||||
clientRequest,
|
clientRequest,
|
||||||
normalizedEndpoint
|
normalizedEndpoint,
|
||||||
|
skipUsageRecord
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -298,7 +312,8 @@ class DroidRelayService {
|
|||||||
account,
|
account,
|
||||||
apiKeyData,
|
apiKeyData,
|
||||||
requestBody,
|
requestBody,
|
||||||
endpointType
|
endpointType,
|
||||||
|
skipUsageRecord = false
|
||||||
) {
|
) {
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
const url = new URL(apiUrl)
|
const url = new URL(apiUrl)
|
||||||
@@ -449,28 +464,34 @@ class DroidRelayService {
|
|||||||
clientResponse.end()
|
clientResponse.end()
|
||||||
|
|
||||||
// 记录 usage 数据
|
// 记录 usage 数据
|
||||||
const normalizedUsage = await this._recordUsageFromStreamData(
|
if (!skipUsageRecord) {
|
||||||
currentUsageData,
|
const normalizedUsage = await this._recordUsageFromStreamData(
|
||||||
apiKeyData,
|
currentUsageData,
|
||||||
account,
|
apiKeyData,
|
||||||
model
|
account,
|
||||||
)
|
model
|
||||||
|
)
|
||||||
|
|
||||||
const usageSummary = {
|
const usageSummary = {
|
||||||
inputTokens: normalizedUsage.input_tokens || 0,
|
inputTokens: normalizedUsage.input_tokens || 0,
|
||||||
outputTokens: normalizedUsage.output_tokens || 0,
|
outputTokens: normalizedUsage.output_tokens || 0,
|
||||||
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
|
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
|
||||||
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
|
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
|
||||||
|
}
|
||||||
|
|
||||||
|
await this._applyRateLimitTracking(
|
||||||
|
clientRequest?.rateLimitInfo,
|
||||||
|
usageSummary,
|
||||||
|
model,
|
||||||
|
' [stream]'
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.success(`✅ Droid stream completed - Account: ${account.name}`)
|
||||||
|
} else {
|
||||||
|
logger.success(
|
||||||
|
`✅ Droid stream completed - Account: ${account.name}, usage recording skipped`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
await this._applyRateLimitTracking(
|
|
||||||
clientRequest?.rateLimitInfo,
|
|
||||||
usageSummary,
|
|
||||||
model,
|
|
||||||
' [stream]'
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.success(`✅ Droid stream completed - Account: ${account.name}`)
|
|
||||||
resolveOnce({ statusCode: 200, streaming: true })
|
resolveOnce({ statusCode: 200, streaming: true })
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -801,11 +822,15 @@ class DroidRelayService {
|
|||||||
/**
|
/**
|
||||||
* 处理请求体(注入 system prompt 等)
|
* 处理请求体(注入 system prompt 等)
|
||||||
*/
|
*/
|
||||||
_processRequestBody(requestBody, endpointType) {
|
_processRequestBody(requestBody, endpointType, options = {}) {
|
||||||
|
const { disableStreaming = false } = options
|
||||||
const processedBody = { ...requestBody }
|
const processedBody = { ...requestBody }
|
||||||
|
|
||||||
// 确保 stream 字段存在
|
if (disableStreaming) {
|
||||||
if (processedBody.stream === undefined) {
|
if ('stream' in processedBody) {
|
||||||
|
delete processedBody.stream
|
||||||
|
}
|
||||||
|
} else if (processedBody.stream === undefined) {
|
||||||
processedBody.stream = true
|
processedBody.stream = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -896,7 +921,8 @@ class DroidRelayService {
|
|||||||
apiKeyData,
|
apiKeyData,
|
||||||
requestBody,
|
requestBody,
|
||||||
clientRequest,
|
clientRequest,
|
||||||
endpointType
|
endpointType,
|
||||||
|
skipUsageRecord = false
|
||||||
) {
|
) {
|
||||||
const { data } = response
|
const { data } = response
|
||||||
|
|
||||||
@@ -906,26 +932,35 @@ class DroidRelayService {
|
|||||||
const model = requestBody.model || 'unknown'
|
const model = requestBody.model || 'unknown'
|
||||||
|
|
||||||
const normalizedUsage = this._normalizeUsageSnapshot(usage)
|
const normalizedUsage = this._normalizeUsageSnapshot(usage)
|
||||||
await this._recordUsage(apiKeyData, account, model, normalizedUsage)
|
|
||||||
|
|
||||||
const totalTokens = this._getTotalTokens(normalizedUsage)
|
if (!skipUsageRecord) {
|
||||||
|
await this._recordUsage(apiKeyData, account, model, normalizedUsage)
|
||||||
|
|
||||||
const usageSummary = {
|
const totalTokens = this._getTotalTokens(normalizedUsage)
|
||||||
inputTokens: normalizedUsage.input_tokens || 0,
|
|
||||||
outputTokens: normalizedUsage.output_tokens || 0,
|
const usageSummary = {
|
||||||
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
|
inputTokens: normalizedUsage.input_tokens || 0,
|
||||||
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
|
outputTokens: normalizedUsage.output_tokens || 0,
|
||||||
|
cacheCreateTokens: normalizedUsage.cache_creation_input_tokens || 0,
|
||||||
|
cacheReadTokens: normalizedUsage.cache_read_input_tokens || 0
|
||||||
|
}
|
||||||
|
|
||||||
|
await this._applyRateLimitTracking(
|
||||||
|
clientRequest?.rateLimitInfo,
|
||||||
|
usageSummary,
|
||||||
|
model,
|
||||||
|
endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]'
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.success(
|
||||||
|
`✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}`
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
logger.success(
|
||||||
|
`✅ Droid request completed - Account: ${account.name}, usage recording skipped`
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
await this._applyRateLimitTracking(
|
|
||||||
clientRequest?.rateLimitInfo,
|
|
||||||
usageSummary,
|
|
||||||
model,
|
|
||||||
endpointType === 'anthropic' ? ' [anthropic]' : ' [openai]'
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.success(`✅ Droid request completed - Account: ${account.name}, Tokens: ${totalTokens}`)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
|||||||
Reference in New Issue
Block a user