fix: 优化并发限制数的控制逻辑

This commit is contained in:
shaw
2025-09-28 13:58:59 +08:00
parent 5ce385d2bc
commit 90dce32cfc
2 changed files with 176 additions and 30 deletions

View File

@@ -1,3 +1,5 @@
const { v4: uuidv4 } = require('uuid')
const config = require('../../config/config')
const apiKeyService = require('../services/apiKeyService')
const userService = require('../services/userService')
const logger = require('../utils/logger')
@@ -80,14 +82,33 @@ const authenticateApiKey = async (req, res, next) => {
// 检查并发限制
const concurrencyLimit = validation.keyData.concurrencyLimit || 0
if (concurrencyLimit > 0) {
const currentConcurrency = await redis.incrConcurrency(validation.keyData.id)
const concurrencyConfig = config.concurrency || {}
const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30)
const rawRenewInterval =
typeof concurrencyConfig.renewIntervalSeconds === 'number'
? concurrencyConfig.renewIntervalSeconds
: 60
let renewIntervalSeconds = rawRenewInterval
if (renewIntervalSeconds > 0) {
const maxSafeRenew = Math.max(leaseSeconds - 5, 15)
renewIntervalSeconds = Math.min(Math.max(renewIntervalSeconds, 15), maxSafeRenew)
} else {
renewIntervalSeconds = 0
}
const requestId = uuidv4()
const currentConcurrency = await redis.incrConcurrency(
validation.keyData.id,
requestId,
leaseSeconds
)
logger.api(
`📈 Incremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), current: ${currentConcurrency}, limit: ${concurrencyLimit}`
)
if (currentConcurrency > concurrencyLimit) {
// 如果超过限制,立即减少计数
await redis.decrConcurrency(validation.keyData.id)
await redis.decrConcurrency(validation.keyData.id, requestId)
logger.security(
`🚦 Concurrency limit exceeded for key: ${validation.keyData.id} (${
validation.keyData.name
@@ -101,14 +122,39 @@ const authenticateApiKey = async (req, res, next) => {
})
}
const renewIntervalMs =
renewIntervalSeconds > 0 ? Math.max(renewIntervalSeconds * 1000, 15000) : 0
// 使用标志位确保只减少一次
let concurrencyDecremented = false
let leaseRenewInterval = null
if (renewIntervalMs > 0) {
leaseRenewInterval = setInterval(() => {
redis
.refreshConcurrencyLease(validation.keyData.id, requestId, leaseSeconds)
.catch((error) => {
logger.error(
`Failed to refresh concurrency lease for key ${validation.keyData.id}:`,
error
)
})
}, renewIntervalMs)
if (typeof leaseRenewInterval.unref === 'function') {
leaseRenewInterval.unref()
}
}
const decrementConcurrency = async () => {
if (!concurrencyDecremented) {
concurrencyDecremented = true
if (leaseRenewInterval) {
clearInterval(leaseRenewInterval)
leaseRenewInterval = null
}
try {
const newCount = await redis.decrConcurrency(validation.keyData.id)
const newCount = await redis.decrConcurrency(validation.keyData.id, requestId)
logger.api(
`📉 Decremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), new count: ${newCount}`
)
@@ -147,6 +193,7 @@ const authenticateApiKey = async (req, res, next) => {
req.concurrencyInfo = {
apiKeyId: validation.keyData.id,
apiKeyName: validation.keyData.name,
requestId,
decrementConcurrency
}
}

View File

@@ -1538,18 +1538,55 @@ class RedisClient {
}
}
// 增加并发计数
async incrConcurrency(apiKeyId) {
// 获取并发配置
_getConcurrencyConfig() {
const defaults = {
leaseSeconds: 900,
cleanupGraceSeconds: 30
}
return {
...defaults,
...(config.concurrency || {})
}
}
// 增加并发计数(基于租约的有序集合)
async incrConcurrency(apiKeyId, requestId, leaseSeconds = null) {
if (!requestId) {
throw new Error('Request ID is required for concurrency tracking')
}
try {
const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } =
this._getConcurrencyConfig()
const lease = leaseSeconds || defaultLeaseSeconds
const key = `concurrency:${apiKeyId}`
const count = await this.client.incr(key)
const now = Date.now()
const expireAt = now + lease * 1000
const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000)
// 设置过期时间为180秒3分钟防止计数器永远不清零
// 正常情况下请求会在完成时主动减少计数,这只是一个安全保障
// 180秒足够支持较长的流式请求
await this.client.expire(key, 180)
const luaScript = `
local key = KEYS[1]
local member = ARGV[1]
local expireAt = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])
logger.database(`🔢 Incremented concurrency for key ${apiKeyId}: ${count}`)
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
redis.call('ZADD', key, expireAt, member)
if ttl > 0 then
redis.call('PEXPIRE', key, ttl)
end
local count = redis.call('ZCARD', key)
return count
`
const count = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl)
logger.database(
`🔢 Incremented concurrency for key ${apiKeyId}: ${count} (request ${requestId})`
)
return count
} catch (error) {
logger.error('❌ Failed to increment concurrency:', error)
@@ -1557,32 +1594,84 @@ class RedisClient {
}
}
// 减少并发计数
async decrConcurrency(apiKeyId) {
try {
const key = `concurrency:${apiKeyId}`
// 刷新并发租约,防止长连接提前过期
async refreshConcurrencyLease(apiKeyId, requestId, leaseSeconds = null) {
if (!requestId) {
return 0
}
try {
const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } =
this._getConcurrencyConfig()
const lease = leaseSeconds || defaultLeaseSeconds
const key = `concurrency:${apiKeyId}`
const now = Date.now()
const expireAt = now + lease * 1000
const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000)
// 使用Lua脚本确保原子性操作防止计数器变成负数
const luaScript = `
local key = KEYS[1]
local current = tonumber(redis.call('get', key) or "0")
local member = ARGV[1]
local expireAt = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])
if current <= 0 then
redis.call('del', key)
return 0
else
local new_value = redis.call('decr', key)
if new_value <= 0 then
redis.call('del', key)
return 0
else
return new_value
local exists = redis.call('ZSCORE', key, member)
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
if exists then
redis.call('ZADD', key, expireAt, member)
if ttl > 0 then
redis.call('PEXPIRE', key, ttl)
end
return 1
end
return 0
`
const count = await this.client.eval(luaScript, 1, key)
logger.database(`🔢 Decremented concurrency for key ${apiKeyId}: ${count}`)
const refreshed = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl)
if (refreshed === 1) {
logger.debug(`🔄 Refreshed concurrency lease for key ${apiKeyId} (request ${requestId})`)
}
return refreshed
} catch (error) {
logger.error('❌ Failed to refresh concurrency lease:', error)
return 0
}
}
// 减少并发计数
async decrConcurrency(apiKeyId, requestId) {
try {
const key = `concurrency:${apiKeyId}`
const now = Date.now()
const luaScript = `
local key = KEYS[1]
local member = ARGV[1]
local now = tonumber(ARGV[2])
if member then
redis.call('ZREM', key, member)
end
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
local count = redis.call('ZCARD', key)
if count <= 0 then
redis.call('DEL', key)
return 0
end
return count
`
const count = await this.client.eval(luaScript, 1, key, requestId || '', now)
logger.database(
`🔢 Decremented concurrency for key ${apiKeyId}: ${count} (request ${requestId || 'n/a'})`
)
return count
} catch (error) {
logger.error('❌ Failed to decrement concurrency:', error)
@@ -1594,7 +1683,17 @@ class RedisClient {
async getConcurrency(apiKeyId) {
try {
const key = `concurrency:${apiKeyId}`
const count = await this.client.get(key)
const now = Date.now()
const luaScript = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
return redis.call('ZCARD', key)
`
const count = await this.client.eval(luaScript, 1, key, now)
return parseInt(count || 0)
} catch (error) {
logger.error('❌ Failed to get concurrency:', error)