diff --git a/src/middleware/auth.js b/src/middleware/auth.js index 241ae568..2b55d69f 100644 --- a/src/middleware/auth.js +++ b/src/middleware/auth.js @@ -1,3 +1,5 @@ +const { v4: uuidv4 } = require('uuid') +const config = require('../../config/config') const apiKeyService = require('../services/apiKeyService') const userService = require('../services/userService') const logger = require('../utils/logger') @@ -80,14 +82,33 @@ const authenticateApiKey = async (req, res, next) => { // 检查并发限制 const concurrencyLimit = validation.keyData.concurrencyLimit || 0 if (concurrencyLimit > 0) { - const currentConcurrency = await redis.incrConcurrency(validation.keyData.id) + const concurrencyConfig = config.concurrency || {} + const leaseSeconds = Math.max(concurrencyConfig.leaseSeconds || 900, 30) + const rawRenewInterval = + typeof concurrencyConfig.renewIntervalSeconds === 'number' + ? concurrencyConfig.renewIntervalSeconds + : 60 + let renewIntervalSeconds = rawRenewInterval + if (renewIntervalSeconds > 0) { + const maxSafeRenew = Math.max(leaseSeconds - 5, 15) + renewIntervalSeconds = Math.min(Math.max(renewIntervalSeconds, 15), maxSafeRenew) + } else { + renewIntervalSeconds = 0 + } + const requestId = uuidv4() + + const currentConcurrency = await redis.incrConcurrency( + validation.keyData.id, + requestId, + leaseSeconds + ) logger.api( `📈 Incremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), current: ${currentConcurrency}, limit: ${concurrencyLimit}` ) if (currentConcurrency > concurrencyLimit) { // 如果超过限制,立即减少计数 - await redis.decrConcurrency(validation.keyData.id) + await redis.decrConcurrency(validation.keyData.id, requestId) logger.security( `🚦 Concurrency limit exceeded for key: ${validation.keyData.id} (${ validation.keyData.name @@ -101,14 +122,39 @@ const authenticateApiKey = async (req, res, next) => { }) } + const renewIntervalMs = + renewIntervalSeconds > 0 ? Math.max(renewIntervalSeconds * 1000, 15000) : 0 + // 使用标志位确保只减少一次 let concurrencyDecremented = false + let leaseRenewInterval = null + + if (renewIntervalMs > 0) { + leaseRenewInterval = setInterval(() => { + redis + .refreshConcurrencyLease(validation.keyData.id, requestId, leaseSeconds) + .catch((error) => { + logger.error( + `Failed to refresh concurrency lease for key ${validation.keyData.id}:`, + error + ) + }) + }, renewIntervalMs) + + if (typeof leaseRenewInterval.unref === 'function') { + leaseRenewInterval.unref() + } + } const decrementConcurrency = async () => { if (!concurrencyDecremented) { concurrencyDecremented = true + if (leaseRenewInterval) { + clearInterval(leaseRenewInterval) + leaseRenewInterval = null + } try { - const newCount = await redis.decrConcurrency(validation.keyData.id) + const newCount = await redis.decrConcurrency(validation.keyData.id, requestId) logger.api( `📉 Decremented concurrency for key: ${validation.keyData.id} (${validation.keyData.name}), new count: ${newCount}` ) @@ -147,6 +193,7 @@ const authenticateApiKey = async (req, res, next) => { req.concurrencyInfo = { apiKeyId: validation.keyData.id, apiKeyName: validation.keyData.name, + requestId, decrementConcurrency } } diff --git a/src/models/redis.js b/src/models/redis.js index 65a89b54..668f5a2c 100644 --- a/src/models/redis.js +++ b/src/models/redis.js @@ -1538,18 +1538,55 @@ class RedisClient { } } - // 增加并发计数 - async incrConcurrency(apiKeyId) { + // 获取并发配置 + _getConcurrencyConfig() { + const defaults = { + leaseSeconds: 900, + cleanupGraceSeconds: 30 + } + return { + ...defaults, + ...(config.concurrency || {}) + } + } + + // 增加并发计数(基于租约的有序集合) + async incrConcurrency(apiKeyId, requestId, leaseSeconds = null) { + if (!requestId) { + throw new Error('Request ID is required for concurrency tracking') + } + try { + const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } = + this._getConcurrencyConfig() + const lease = leaseSeconds || defaultLeaseSeconds const key = `concurrency:${apiKeyId}` - const count = await this.client.incr(key) + const now = Date.now() + const expireAt = now + lease * 1000 + const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000) - // 设置过期时间为180秒(3分钟),防止计数器永远不清零 - // 正常情况下请求会在完成时主动减少计数,这只是一个安全保障 - // 180秒足够支持较长的流式请求 - await this.client.expire(key, 180) + const luaScript = ` + local key = KEYS[1] + local member = ARGV[1] + local expireAt = tonumber(ARGV[2]) + local now = tonumber(ARGV[3]) + local ttl = tonumber(ARGV[4]) - logger.database(`🔢 Incremented concurrency for key ${apiKeyId}: ${count}`) + redis.call('ZREMRANGEBYSCORE', key, '-inf', now) + redis.call('ZADD', key, expireAt, member) + + if ttl > 0 then + redis.call('PEXPIRE', key, ttl) + end + + local count = redis.call('ZCARD', key) + return count + ` + + const count = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl) + logger.database( + `🔢 Incremented concurrency for key ${apiKeyId}: ${count} (request ${requestId})` + ) return count } catch (error) { logger.error('❌ Failed to increment concurrency:', error) @@ -1557,32 +1594,84 @@ class RedisClient { } } - // 减少并发计数 - async decrConcurrency(apiKeyId) { - try { - const key = `concurrency:${apiKeyId}` + // 刷新并发租约,防止长连接提前过期 + async refreshConcurrencyLease(apiKeyId, requestId, leaseSeconds = null) { + if (!requestId) { + return 0 + } + + try { + const { leaseSeconds: defaultLeaseSeconds, cleanupGraceSeconds } = + this._getConcurrencyConfig() + const lease = leaseSeconds || defaultLeaseSeconds + const key = `concurrency:${apiKeyId}` + const now = Date.now() + const expireAt = now + lease * 1000 + const ttl = Math.max((lease + cleanupGraceSeconds) * 1000, 60000) - // 使用Lua脚本确保原子性操作,防止计数器变成负数 const luaScript = ` local key = KEYS[1] - local current = tonumber(redis.call('get', key) or "0") + local member = ARGV[1] + local expireAt = tonumber(ARGV[2]) + local now = tonumber(ARGV[3]) + local ttl = tonumber(ARGV[4]) - if current <= 0 then - redis.call('del', key) - return 0 - else - local new_value = redis.call('decr', key) - if new_value <= 0 then - redis.call('del', key) - return 0 - else - return new_value + local exists = redis.call('ZSCORE', key, member) + + redis.call('ZREMRANGEBYSCORE', key, '-inf', now) + + if exists then + redis.call('ZADD', key, expireAt, member) + if ttl > 0 then + redis.call('PEXPIRE', key, ttl) end + return 1 end + + return 0 ` - const count = await this.client.eval(luaScript, 1, key) - logger.database(`🔢 Decremented concurrency for key ${apiKeyId}: ${count}`) + const refreshed = await this.client.eval(luaScript, 1, key, requestId, expireAt, now, ttl) + if (refreshed === 1) { + logger.debug(`🔄 Refreshed concurrency lease for key ${apiKeyId} (request ${requestId})`) + } + return refreshed + } catch (error) { + logger.error('❌ Failed to refresh concurrency lease:', error) + return 0 + } + } + + // 减少并发计数 + async decrConcurrency(apiKeyId, requestId) { + try { + const key = `concurrency:${apiKeyId}` + const now = Date.now() + + const luaScript = ` + local key = KEYS[1] + local member = ARGV[1] + local now = tonumber(ARGV[2]) + + if member then + redis.call('ZREM', key, member) + end + + redis.call('ZREMRANGEBYSCORE', key, '-inf', now) + + local count = redis.call('ZCARD', key) + if count <= 0 then + redis.call('DEL', key) + return 0 + end + + return count + ` + + const count = await this.client.eval(luaScript, 1, key, requestId || '', now) + logger.database( + `🔢 Decremented concurrency for key ${apiKeyId}: ${count} (request ${requestId || 'n/a'})` + ) return count } catch (error) { logger.error('❌ Failed to decrement concurrency:', error) @@ -1594,7 +1683,17 @@ class RedisClient { async getConcurrency(apiKeyId) { try { const key = `concurrency:${apiKeyId}` - const count = await this.client.get(key) + const now = Date.now() + + const luaScript = ` + local key = KEYS[1] + local now = tonumber(ARGV[1]) + + redis.call('ZREMRANGEBYSCORE', key, '-inf', now) + return redis.call('ZCARD', key) + ` + + const count = await this.client.eval(luaScript, 1, key, now) return parseInt(count || 0) } catch (error) { logger.error('❌ Failed to get concurrency:', error)