fix: 优化并发限制数的控制逻辑

This commit is contained in:
shaw
2025-09-28 13:58:59 +08:00
parent 5ce385d2bc
commit 90dce32cfc
2 changed files with 176 additions and 30 deletions

View File

@@ -1,3 +1,5 @@
const { v4: uuidv4 } = require('uuid')
const config = require('../../config/config')
const apiKeyService = require('../services/apiKeyService') const apiKeyService = require('../services/apiKeyService')
const userService = require('../services/userService') const userService = require('../services/userService')
const logger = require('../utils/logger') const logger = require('../utils/logger')
@@ -80,14 +82,33 @@ const authenticateApiKey = async (req, res, next) => {
// 检查并发限制 // 检查并发限制
const concurrencyLimit = validation.keyData.concurrencyLimit || 0 const concurrencyLimit = validation.keyData.concurrencyLimit || 0
if (concurrencyLimit > 0) { if (concurrencyLimit > 0) {
const currentConcurrency = await redis.incrConcurrency(validation.keyData.id) const concurrencyConfig = config.concurrency || {}
const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30)
const rawRenewInterval =
typeof concurrencyConfig.renewIntervalSeconds === 'number'
? concurrencyConfig.renewIntervalSeconds
: 60
let renewIntervalSeconds = rawRenewInterval
if (renewIntervalSeconds > 0) {
const maxSafeRenew = Math.max(leaseSeconds - 5, 15)
renewIntervalSeconds = Math.min(Math.max(renewIntervalSeconds, 15), maxSafeRenew)
} else {
renewIntervalSeconds = 0
}
const requestId = uuidv4()
const currentConcurrency = await redis.incrConcurrency(
validation.keyData.id,
requestId,
leaseSeconds
)
logger.api( logger.api(
`📈 Incremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), current: ${currentConcurrency}, limit: ${concurrencyLimit}` `📈 Incremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), current: ${currentConcurrency}, limit: ${concurrencyLimit}`
) )
if (currentConcurrency > concurrencyLimit) { if (currentConcurrency > concurrencyLimit) {
// 如果超过限制,立即减少计数 // 如果超过限制,立即减少计数
await redis.decrConcurrency(validation.keyData.id) await redis.decrConcurrency(validation.keyData.id, requestId)
logger.security( logger.security(
`🚦 Concurrency limit exceeded for key: ${validation.keyData.id} (${ `🚦 Concurrency limit exceeded for key: ${validation.keyData.id} (${
validation.keyData.name validation.keyData.name
@@ -101,14 +122,39 @@ const authenticateApiKey = async (req, res, next) => {
}) })
} }
const renewIntervalMs =
renewIntervalSeconds > 0 ? Math.max(renewIntervalSeconds * 1000, 15000) : 0
// 使用标志位确保只减少一次 // 使用标志位确保只减少一次
let concurrencyDecremented = false let concurrencyDecremented = false
let leaseRenewInterval = null
if (renewIntervalMs > 0) {
leaseRenewInterval = setInterval(() => {
redis
.refreshConcurrencyLease(validation.keyData.id, requestId, leaseSeconds)
.catch((error) => {
logger.error(
`Failed to refresh concurrency lease for key ${validation.keyData.id}:`,
error
)
})
}, renewIntervalMs)
if (typeof leaseRenewInterval.unref === 'function') {
leaseRenewInterval.unref()
}
}
const decrementConcurrency = async () => { const decrementConcurrency = async () => {
if (!concurrencyDecremented) { if (!concurrencyDecremented) {
concurrencyDecremented = true concurrencyDecremented = true
if (leaseRenewInterval) {
clearInterval(leaseRenewInterval)
leaseRenewInterval = null
}
try { try {
const newCount = await redis.decrConcurrency(validation.keyData.id) const newCount = await redis.decrConcurrency(validation.keyData.id, requestId)
logger.api( logger.api(
`📉 Decremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), new count: ${newCount}` `📉 Decremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), new count: ${newCount}`
) )
@@ -147,6 +193,7 @@ const authenticateApiKey = async (req, res, next) => {
req.concurrencyInfo = { req.concurrencyInfo = {
apiKeyId: validation.keyData.id, apiKeyId: validation.keyData.id,
apiKeyName: validation.keyData.name, apiKeyName: validation.keyData.name,
requestId,
decrementConcurrency decrementConcurrency
} }
} }

View File

@@ -1538,18 +1538,55 @@ class RedisClient {
} }
} }
// 增加并发计数 // 获取并发配置
async incrConcurrency(apiKeyId) { _getConcurrencyConfig() {
const defaults = {
leaseSeconds: 900,
cleanupGraceSeconds: 30
}
return {
...defaults,
...(config.concurrency || {})
}
}
// 增加并发计数(基于租约的有序集合)
async incrConcurrency(apiKeyId, requestId, leaseSeconds = null) {
if (!requestId) {
throw new Error('Request ID is required for concurrency tracking')
}
try { try {
const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } =
this._getConcurrencyConfig()
const lease = leaseSeconds || defaultLeaseSeconds
const key = `concurrency:${apiKeyId}` const key = `concurrency:${apiKeyId}`
const count = await this.client.incr(key) const now = Date.now()
const expireAt = now + lease * 1000
const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000)
// 设置过期时间为180秒3分钟防止计数器永远不清零 const luaScript = `
// 正常情况下请求会在完成时主动减少计数,这只是一个安全保障 local key = KEYS[1]
// 180秒足够支持较长的流式请求 local member = ARGV[1]
await this.client.expire(key, 180) local expireAt = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])
logger.database(`🔢 Incremented concurrency for key ${apiKeyId}: ${count}`) redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
redis.call('ZADD', key, expireAt, member)
if ttl > 0 then
redis.call('PEXPIRE', key, ttl)
end
local count = redis.call('ZCARD', key)
return count
`
const count = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl)
logger.database(
`🔢 Incremented concurrency for key ${apiKeyId}: ${count} (request ${requestId})`
)
return count return count
} catch (error) { } catch (error) {
logger.error('❌ Failed to increment concurrency:', error) logger.error('❌ Failed to increment concurrency:', error)
@@ -1557,32 +1594,84 @@ class RedisClient {
} }
} }
// 减少并发计数 // 刷新并发租约,防止长连接提前过期
async decrConcurrency(apiKeyId) { async refreshConcurrencyLease(apiKeyId, requestId, leaseSeconds = null) {
try { if (!requestId) {
const key = `concurrency:${apiKeyId}` return 0
}
try {
const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } =
this._getConcurrencyConfig()
const lease = leaseSeconds || defaultLeaseSeconds
const key = `concurrency:${apiKeyId}`
const now = Date.now()
const expireAt = now + lease * 1000
const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000)
// 使用Lua脚本确保原子性操作防止计数器变成负数
const luaScript = ` const luaScript = `
local key = KEYS[1] local key = KEYS[1]
local current = tonumber(redis.call('get', key) or "0") local member = ARGV[1]
local expireAt = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local ttl = tonumber(ARGV[4])
if current <= 0 then local exists = redis.call('ZSCORE', key, member)
redis.call('del', key)
return 0 redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
else
local new_value = redis.call('decr', key) if exists then
if new_value <= 0 then redis.call('ZADD', key, expireAt, member)
redis.call('del', key) if ttl > 0 then
return 0 redis.call('PEXPIRE', key, ttl)
else
return new_value
end end
return 1
end end
return 0
` `
const count = await this.client.eval(luaScript, 1, key) const refreshed = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl)
logger.database(`🔢 Decremented concurrency for key ${apiKeyId}: ${count}`) if (refreshed === 1) {
logger.debug(`🔄 Refreshed concurrency lease for key ${apiKeyId} (request ${requestId})`)
}
return refreshed
} catch (error) {
logger.error('❌ Failed to refresh concurrency lease:', error)
return 0
}
}
// 减少并发计数
async decrConcurrency(apiKeyId, requestId) {
try {
const key = `concurrency:${apiKeyId}`
const now = Date.now()
const luaScript = `
local key = KEYS[1]
local member = ARGV[1]
local now = tonumber(ARGV[2])
if member then
redis.call('ZREM', key, member)
end
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
local count = redis.call('ZCARD', key)
if count <= 0 then
redis.call('DEL', key)
return 0
end
return count
`
const count = await this.client.eval(luaScript, 1, key, requestId || '', now)
logger.database(
`🔢 Decremented concurrency for key ${apiKeyId}: ${count} (request ${requestId || 'n/a'})`
)
return count return count
} catch (error) { } catch (error) {
logger.error('❌ Failed to decrement concurrency:', error) logger.error('❌ Failed to decrement concurrency:', error)
@@ -1594,7 +1683,17 @@ class RedisClient {
async getConcurrency(apiKeyId) { async getConcurrency(apiKeyId) {
try { try {
const key = `concurrency:${apiKeyId}` const key = `concurrency:${apiKeyId}`
const count = await this.client.get(key) const now = Date.now()
const luaScript = `
local key = KEYS[1]
local now = tonumber(ARGV[1])
redis.call('ZREMRANGEBYSCORE', key, '-inf', now)
return redis.call('ZCARD', key)
`
const count = await this.client.eval(luaScript, 1, key, now)
return parseInt(count || 0) return parseInt(count || 0)
} catch (error) { } catch (error) {
logger.error('❌ Failed to get concurrency:', error) logger.error('❌ Failed to get concurrency:', error)