mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-23 00:53:33 +00:00
fix: 优化并发限制数的控制逻辑
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
const { v4: uuidv4 } = require('uuid')
|
||||
const config = require('../../config/config')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
const userService = require('../services/userService')
|
||||
const logger = require('../utils/logger')
|
||||
@@ -80,14 +82,33 @@ const authenticateApiKey = async (req, res, next) => {
|
||||
// 检查并发限制
|
||||
const concurrencyLimit = validation.keyData.concurrencyLimit || 0
|
||||
if (concurrencyLimit > 0) {
|
||||
const currentConcurrency = await redis.incrConcurrency(validation.keyData.id)
|
||||
const concurrencyConfig = config.concurrency || {}
|
||||
const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30)
|
||||
const rawRenewInterval =
|
||||
typeof concurrencyConfig.renewIntervalSeconds === 'number'
|
||||
? concurrencyConfig.renewIntervalSeconds
|
||||
: 60
|
||||
let renewIntervalSeconds = rawRenewInterval
|
||||
if (renewIntervalSeconds > 0) {
|
||||
const maxSafeRenew = Math.max(leaseSeconds - 5, 15)
|
||||
renewIntervalSeconds = Math.min(Math.max(renewIntervalSeconds, 15), maxSafeRenew)
|
||||
} else {
|
||||
renewIntervalSeconds = 0
|
||||
}
|
||||
const requestId = uuidv4()
|
||||
|
||||
const currentConcurrency = await redis.incrConcurrency(
|
||||
validation.keyData.id,
|
||||
requestId,
|
||||
leaseSeconds
|
||||
)
|
||||
logger.api(
|
||||
`📈 Incremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), current: ${currentConcurrency}, limit: ${concurrencyLimit}`
|
||||
)
|
||||
|
||||
if (currentConcurrency > concurrencyLimit) {
|
||||
// 如果超过限制,立即减少计数
|
||||
await redis.decrConcurrency(validation.keyData.id)
|
||||
await redis.decrConcurrency(validation.keyData.id, requestId)
|
||||
logger.security(
|
||||
`🚦 Concurrency limit exceeded for key: ${validation.keyData.id} (${
|
||||
validation.keyData.name
|
||||
@@ -101,14 +122,39 @@ const authenticateApiKey = async (req, res, next) => {
|
||||
})
|
||||
}
|
||||
|
||||
const renewIntervalMs =
|
||||
renewIntervalSeconds > 0 ? Math.max(renewIntervalSeconds * 1000, 15000) : 0
|
||||
|
||||
// 使用标志位确保只减少一次
|
||||
let concurrencyDecremented = false
|
||||
let leaseRenewInterval = null
|
||||
|
||||
if (renewIntervalMs > 0) {
|
||||
leaseRenewInterval = setInterval(() => {
|
||||
redis
|
||||
.refreshConcurrencyLease(validation.keyData.id, requestId, leaseSeconds)
|
||||
.catch((error) => {
|
||||
logger.error(
|
||||
`Failed to refresh concurrency lease for key ${validation.keyData.id}:`,
|
||||
error
|
||||
)
|
||||
})
|
||||
}, renewIntervalMs)
|
||||
|
||||
if (typeof leaseRenewInterval.unref === 'function') {
|
||||
leaseRenewInterval.unref()
|
||||
}
|
||||
}
|
||||
|
||||
const decrementConcurrency = async () => {
|
||||
if (!concurrencyDecremented) {
|
||||
concurrencyDecremented = true
|
||||
if (leaseRenewInterval) {
|
||||
clearInterval(leaseRenewInterval)
|
||||
leaseRenewInterval = null
|
||||
}
|
||||
try {
|
||||
const newCount = await redis.decrConcurrency(validation.keyData.id)
|
||||
const newCount = await redis.decrConcurrency(validation.keyData.id, requestId)
|
||||
logger.api(
|
||||
`📉 Decremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), new count: ${newCount}`
|
||||
)
|
||||
@@ -147,6 +193,7 @@ const authenticateApiKey = async (req, res, next) => {
|
||||
req.concurrencyInfo = {
|
||||
apiKeyId: validation.keyData.id,
|
||||
apiKeyName: validation.keyData.name,
|
||||
requestId,
|
||||
decrementConcurrency
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1538,18 +1538,55 @@ class RedisClient {
|
||||
}
|
||||
}
|
||||
|
||||
// 增加并发计数
|
||||
async incrConcurrency(apiKeyId) {
|
||||
// Build the effective concurrency settings.
// Starts from the built-in defaults below and overlays whatever fields are
// present on config.concurrency; when config.concurrency is missing/falsy an
// empty object is spread instead, so the defaults are returned as-is.
// Returns a fresh plain object on every call.
|
||||
_getConcurrencyConfig() {
|
||||
const defaults = {
|
||||
leaseSeconds: 900,
|
||||
cleanupGraceSeconds: 30
|
||||
}
|
||||
return {
|
||||
// Spread order matters: the later spread (config.concurrency) overrides defaults.
...defaults,
|
||||
...(config.concurrency || {})
|
||||
}
|
||||
}
|
||||
|
||||
// 增加并发计数(基于租约的有序集合)
|
||||
async incrConcurrency(apiKeyId, requestId, leaseSeconds = null) {
|
||||
if (!requestId) {
|
||||
throw new Error('Request ID is required for concurrency tracking')
|
||||
}
|
||||
|
||||
try {
|
||||
const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } =
|
||||
this._getConcurrencyConfig()
|
||||
const lease = leaseSeconds || defaultLeaseSeconds
|
||||
const key = `concurrency:${apiKeyId}`
|
||||
const count = await this.client.incr(key)
|
||||
const now = Date.now()
|
||||
const expireAt = now + lease * 1000
|
||||
const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000)
|
||||
|
||||
// 设置过期时间为180秒(3分钟),防止计数器永远不清零
|
||||
// 正常情况下请求会在完成时主动减少计数,这只是一个安全保障
|
||||
// 180秒足够支持较长的流式请求
|
||||
await this.client.expire(key, 180)
|
||||
const luaScript = `
|
||||
local key = KEYS[1]
|
||||
local member = ARGV[1]
|
||||
local expireAt = tonumber(ARGV[2])
|
||||
local now = tonumber(ARGV[3])
|
||||
local ttl = tonumber(ARGV[4])
|
||||
|
||||
logger.database(`🔢 Incremented concurrency for key ${apiKeyId}: ${count}`)
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
|
||||
redis.call('ZADD', key, expireAt, member)
|
||||
|
||||
if ttl > 0 then
|
||||
redis.call('PEXPIRE', key, ttl)
|
||||
end
|
||||
|
||||
local count = redis.call('ZCARD', key)
|
||||
return count
|
||||
`
|
||||
|
||||
const count = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl)
|
||||
logger.database(
|
||||
`🔢 Incremented concurrency for key ${apiKeyId}: ${count} (request ${requestId})`
|
||||
)
|
||||
return count
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to increment concurrency:', error)
|
||||
@@ -1557,32 +1594,84 @@ class RedisClient {
|
||||
}
|
||||
}
|
||||
|
||||
// 减少并发计数
|
||||
async decrConcurrency(apiKeyId) {
|
||||
try {
|
||||
const key = `concurrency:${apiKeyId}`
|
||||
// Refresh the concurrency lease so that long-lived connections do not expire early.
//
// Parameters:
//   apiKeyId     - used to build the Redis key `concurrency:${apiKeyId}`.
//   requestId    - sorted-set member identifying this request; when falsy the
//                  method returns 0 immediately without touching Redis.
//   leaseSeconds - lease length in seconds; any falsy value (null, 0, undefined)
//                  falls back to leaseSeconds from _getConcurrencyConfig().
//
// Returns the value produced by the Lua script (1 when the member was found and
// re-scored, 0 otherwise). Any thrown error is logged and 0 is returned instead
// of propagating.
|
||||
async refreshConcurrencyLease(apiKeyId, requestId, leaseSeconds = null) {
|
||||
if (!requestId) {
|
||||
return 0
|
||||
}
|
||||
|
||||
try {
|
||||
const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } =
|
||||
this._getConcurrencyConfig()
|
||||
const lease = leaseSeconds || defaultLeaseSeconds
|
||||
const key = `concurrency:${apiKeyId}`
|
||||
const now = Date.now()
|
||||
// New score for the member: absolute expiry time in epoch milliseconds.
const expireAt = now + lease * 1000
|
||||
// Key-level TTL in milliseconds (lease plus grace period), never below 60000.
const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000)
|
||||
|
||||
// Use a Lua script so the whole operation runs atomically (the original comment
// also mentions preventing the counter from going negative).
// NOTE(review): this page is a diff rendering; the `get`/`decr`/`current`/`new_value`
// lines inside the script below look like removed lines from the previous
// decrConcurrency script interleaved with the new lease script — confirm against
// the real file.
// In the lease script, ZSCORE is read before ZREMRANGEBYSCORE purges entries with
// score <= now, and the member is re-added (ZADD) with expireAt only if it was
// present at that first read; PEXPIRE is then applied when ttl > 0.
|
||||
const luaScript = `
|
||||
local key = KEYS[1]
|
||||
local current = tonumber(redis.call('get', key) or "0")
|
||||
local member = ARGV[1]
|
||||
local expireAt = tonumber(ARGV[2])
|
||||
local now = tonumber(ARGV[3])
|
||||
local ttl = tonumber(ARGV[4])
|
||||
|
||||
if current <= 0 then
|
||||
redis.call('del', key)
|
||||
return 0
|
||||
else
|
||||
local new_value = redis.call('decr', key)
|
||||
if new_value <= 0 then
|
||||
redis.call('del', key)
|
||||
return 0
|
||||
else
|
||||
return new_value
|
||||
local exists = redis.call('ZSCORE', key, member)
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
|
||||
|
||||
if exists then
|
||||
redis.call('ZADD', key, expireAt, member)
|
||||
if ttl > 0 then
|
||||
redis.call('PEXPIRE', key, ttl)
|
||||
end
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`
|
||||
|
||||
// NOTE(review): the next two statements (`count`, "Decremented concurrency") appear
// to be removed lines from the old decrConcurrency shown by the diff view — confirm.
const count = await this.client.eval(luaScript, 1, key)
|
||||
logger.database(`🔢 Decremented concurrency for key ${apiKeyId}: ${count}`)
|
||||
// One key (KEYS[1]) followed by ARGV[1..4] = requestId, expireAt, now, ttl.
const refreshed = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl)
|
||||
if (refreshed === 1) {
|
||||
logger.debug(`🔄 Refreshed concurrency lease for key ${apiKeyId} (request ${requestId})`)
|
||||
}
|
||||
return refreshed
|
||||
} catch (error) {
|
||||
// Best-effort: log the failure and report "not refreshed" instead of throwing.
logger.error('❌ Failed to refresh concurrency lease:', error)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// 减少并发计数
|
||||
async decrConcurrency(apiKeyId, requestId) {
|
||||
try {
|
||||
const key = `concurrency:${apiKeyId}`
|
||||
const now = Date.now()
|
||||
|
||||
const luaScript = `
|
||||
local key = KEYS[1]
|
||||
local member = ARGV[1]
|
||||
local now = tonumber(ARGV[2])
|
||||
|
||||
if member then
|
||||
redis.call('ZREM', key, member)
|
||||
end
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
|
||||
|
||||
local count = redis.call('ZCARD', key)
|
||||
if count <= 0 then
|
||||
redis.call('DEL', key)
|
||||
return 0
|
||||
end
|
||||
|
||||
return count
|
||||
`
|
||||
|
||||
const count = await this.client.eval(luaScript, 1, key, requestId || '', now)
|
||||
logger.database(
|
||||
`🔢 Decremented concurrency for key ${apiKeyId}: ${count} (request ${requestId || 'n/a'})`
|
||||
)
|
||||
return count
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to decrement concurrency:', error)
|
||||
@@ -1594,7 +1683,17 @@ class RedisClient {
|
||||
async getConcurrency(apiKeyId) {
|
||||
try {
|
||||
const key = `concurrency:${apiKeyId}`
|
||||
const count = await this.client.get(key)
|
||||
const now = Date.now()
|
||||
|
||||
const luaScript = `
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
|
||||
return redis.call('ZCARD', key)
|
||||
`
|
||||
|
||||
const count = await this.client.eval(luaScript, 1, key, now)
|
||||
return parseInt(count || 0)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get concurrency:', error)
|
||||
|
||||
Reference in New Issue
Block a user