mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-23 09:38:02 +00:00
@@ -70,7 +70,9 @@ class ApiKeyService {
|
||||
createdAt: new Date().toISOString(),
|
||||
lastUsedAt: '',
|
||||
expiresAt: expiresAt || '',
|
||||
createdBy: 'admin' // 可以根据需要扩展用户系统
|
||||
createdBy: options.createdBy || 'admin',
|
||||
userId: options.userId || '',
|
||||
userUsername: options.userUsername || ''
|
||||
}
|
||||
|
||||
// 保存API Key数据并建立哈希映射
|
||||
@@ -136,6 +138,20 @@ class ApiKeyService {
|
||||
return { valid: false, error: 'API key has expired' }
|
||||
}
|
||||
|
||||
// 如果API Key属于某个用户,检查用户是否被禁用
|
||||
if (keyData.userId) {
|
||||
try {
|
||||
const userService = require('./userService')
|
||||
const user = await userService.getUserById(keyData.userId, false)
|
||||
if (!user || !user.isActive) {
|
||||
return { valid: false, error: 'User account is disabled' }
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Error checking user status during API key validation:', error)
|
||||
return { valid: false, error: 'Unable to validate user status' }
|
||||
}
|
||||
}
|
||||
|
||||
// 获取使用统计(供返回数据使用)
|
||||
const usage = await redis.getUsageStats(keyData.id)
|
||||
|
||||
@@ -210,14 +226,27 @@ class ApiKeyService {
|
||||
}
|
||||
|
||||
// 📋 获取所有API Keys
|
||||
async getAllApiKeys() {
|
||||
async getAllApiKeys(includeDeleted = false) {
|
||||
try {
|
||||
const apiKeys = await redis.getAllApiKeys()
|
||||
let apiKeys = await redis.getAllApiKeys()
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 默认过滤掉已删除的API Keys
|
||||
if (!includeDeleted) {
|
||||
apiKeys = apiKeys.filter((key) => key.isDeleted !== 'true')
|
||||
}
|
||||
|
||||
// 为每个key添加使用统计和当前并发数
|
||||
for (const key of apiKeys) {
|
||||
key.usage = await redis.getUsageStats(key.id)
|
||||
const costStats = await redis.getCostStats(key.id)
|
||||
// Add cost information to usage object for frontend compatibility
|
||||
if (key.usage && costStats) {
|
||||
key.usage.total = key.usage.total || {}
|
||||
key.usage.total.cost = costStats.total
|
||||
key.usage.totalCost = costStats.total
|
||||
}
|
||||
key.totalCost = costStats ? costStats.total : 0
|
||||
key.tokenLimit = parseInt(key.tokenLimit)
|
||||
key.concurrencyLimit = parseInt(key.concurrencyLimit || 0)
|
||||
key.rateLimitWindow = parseInt(key.rateLimitWindow || 0)
|
||||
@@ -371,16 +400,32 @@ class ApiKeyService {
|
||||
}
|
||||
}
|
||||
|
||||
// 🗑️ 删除API Key
|
||||
async deleteApiKey(keyId) {
|
||||
// 🗑️ 软删除API Key (保留使用统计)
|
||||
async deleteApiKey(keyId, deletedBy = 'system', deletedByType = 'system') {
|
||||
try {
|
||||
const result = await redis.deleteApiKey(keyId)
|
||||
|
||||
if (result === 0) {
|
||||
const keyData = await redis.getApiKey(keyId)
|
||||
if (!keyData || Object.keys(keyData).length === 0) {
|
||||
throw new Error('API key not found')
|
||||
}
|
||||
|
||||
logger.success(`🗑️ Deleted API key: ${keyId}`)
|
||||
// 标记为已删除,保留所有数据和统计信息
|
||||
const updatedData = {
|
||||
...keyData,
|
||||
isDeleted: 'true',
|
||||
deletedAt: new Date().toISOString(),
|
||||
deletedBy,
|
||||
deletedByType, // 'user', 'admin', 'system'
|
||||
isActive: 'false' // 同时禁用
|
||||
}
|
||||
|
||||
await redis.setApiKey(keyId, updatedData)
|
||||
|
||||
// 从哈希映射中移除(这样就不能再使用这个key进行API调用)
|
||||
if (keyData.apiKey) {
|
||||
await redis.deleteApiKeyHash(keyData.apiKey)
|
||||
}
|
||||
|
||||
logger.success(`🗑️ Soft deleted API key: ${keyId} by ${deletedBy} (${deletedByType})`)
|
||||
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
@@ -672,6 +717,225 @@ class ApiKeyService {
|
||||
return await redis.getAllAccountsUsageStats()
|
||||
}
|
||||
|
||||
// === 用户相关方法 ===
|
||||
|
||||
// 🔑 创建API Key(支持用户)
|
||||
async createApiKey(options = {}) {
|
||||
return await this.generateApiKey(options)
|
||||
}
|
||||
|
||||
// 👤 获取用户的API Keys
|
||||
async getUserApiKeys(userId, includeDeleted = false) {
|
||||
try {
|
||||
const allKeys = await redis.getAllApiKeys()
|
||||
let userKeys = allKeys.filter((key) => key.userId === userId)
|
||||
|
||||
// 默认过滤掉已删除的API Keys
|
||||
if (!includeDeleted) {
|
||||
userKeys = userKeys.filter((key) => key.isDeleted !== 'true')
|
||||
}
|
||||
|
||||
// Populate usage stats for each user's API key (same as getAllApiKeys does)
|
||||
const userKeysWithUsage = []
|
||||
for (const key of userKeys) {
|
||||
const usage = await redis.getUsageStats(key.id)
|
||||
const dailyCost = (await redis.getDailyCost(key.id)) || 0
|
||||
const costStats = await redis.getCostStats(key.id)
|
||||
|
||||
userKeysWithUsage.push({
|
||||
id: key.id,
|
||||
name: key.name,
|
||||
description: key.description,
|
||||
key: key.apiKey ? `${this.prefix}****${key.apiKey.slice(-4)}` : null, // 只显示前缀和后4位
|
||||
tokenLimit: parseInt(key.tokenLimit || 0),
|
||||
isActive: key.isActive === 'true',
|
||||
createdAt: key.createdAt,
|
||||
lastUsedAt: key.lastUsedAt,
|
||||
expiresAt: key.expiresAt,
|
||||
usage,
|
||||
dailyCost,
|
||||
totalCost: costStats.total,
|
||||
dailyCostLimit: parseFloat(key.dailyCostLimit || 0),
|
||||
userId: key.userId,
|
||||
userUsername: key.userUsername,
|
||||
createdBy: key.createdBy,
|
||||
// Include deletion fields for deleted keys
|
||||
isDeleted: key.isDeleted,
|
||||
deletedAt: key.deletedAt,
|
||||
deletedBy: key.deletedBy,
|
||||
deletedByType: key.deletedByType
|
||||
})
|
||||
}
|
||||
|
||||
return userKeysWithUsage
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get user API keys:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 通过ID获取API Key(检查权限)
|
||||
async getApiKeyById(keyId, userId = null) {
|
||||
try {
|
||||
const keyData = await redis.getApiKey(keyId)
|
||||
if (!keyData) {
|
||||
return null
|
||||
}
|
||||
|
||||
// 如果指定了用户ID,检查权限
|
||||
if (userId && keyData.userId !== userId) {
|
||||
return null
|
||||
}
|
||||
|
||||
return {
|
||||
id: keyData.id,
|
||||
name: keyData.name,
|
||||
description: keyData.description,
|
||||
key: keyData.apiKey,
|
||||
tokenLimit: parseInt(keyData.tokenLimit || 0),
|
||||
isActive: keyData.isActive === 'true',
|
||||
createdAt: keyData.createdAt,
|
||||
lastUsedAt: keyData.lastUsedAt,
|
||||
expiresAt: keyData.expiresAt,
|
||||
userId: keyData.userId,
|
||||
userUsername: keyData.userUsername,
|
||||
createdBy: keyData.createdBy,
|
||||
permissions: keyData.permissions,
|
||||
dailyCostLimit: parseFloat(keyData.dailyCostLimit || 0)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get API key by ID:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
// 🔄 重新生成API Key
|
||||
async regenerateApiKey(keyId) {
|
||||
try {
|
||||
const existingKey = await redis.getApiKey(keyId)
|
||||
if (!existingKey) {
|
||||
throw new Error('API key not found')
|
||||
}
|
||||
|
||||
// 生成新的key
|
||||
const newApiKey = `${this.prefix}${this._generateSecretKey()}`
|
||||
const newHashedKey = this._hashApiKey(newApiKey)
|
||||
|
||||
// 删除旧的哈希映射
|
||||
const oldHashedKey = existingKey.apiKey
|
||||
await redis.deleteApiKeyHash(oldHashedKey)
|
||||
|
||||
// 更新key数据
|
||||
const updatedKeyData = {
|
||||
...existingKey,
|
||||
apiKey: newHashedKey,
|
||||
updatedAt: new Date().toISOString()
|
||||
}
|
||||
|
||||
// 保存新数据并建立新的哈希映射
|
||||
await redis.setApiKey(keyId, updatedKeyData, newHashedKey)
|
||||
|
||||
logger.info(`🔄 Regenerated API key: ${existingKey.name} (${keyId})`)
|
||||
|
||||
return {
|
||||
id: keyId,
|
||||
name: existingKey.name,
|
||||
key: newApiKey, // 返回完整的新key
|
||||
updatedAt: updatedKeyData.updatedAt
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to regenerate API key:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🗑️ 硬删除API Key (完全移除)
|
||||
async hardDeleteApiKey(keyId) {
|
||||
try {
|
||||
const keyData = await redis.getApiKey(keyId)
|
||||
if (!keyData) {
|
||||
throw new Error('API key not found')
|
||||
}
|
||||
|
||||
// 删除key数据和哈希映射
|
||||
await redis.deleteApiKey(keyId)
|
||||
await redis.deleteApiKeyHash(keyData.apiKey)
|
||||
|
||||
logger.info(`🗑️ Deleted API key: ${keyData.name} (${keyId})`)
|
||||
return true
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to delete API key:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🚫 禁用用户的所有API Keys
|
||||
async disableUserApiKeys(userId) {
|
||||
try {
|
||||
const userKeys = await this.getUserApiKeys(userId)
|
||||
let disabledCount = 0
|
||||
|
||||
for (const key of userKeys) {
|
||||
if (key.isActive) {
|
||||
await this.updateApiKey(key.id, { isActive: false })
|
||||
disabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`🚫 Disabled ${disabledCount} API keys for user: ${userId}`)
|
||||
return { count: disabledCount }
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to disable user API keys:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 获取聚合使用统计(支持多个API Key)
|
||||
async getAggregatedUsageStats(keyIds, options = {}) {
|
||||
try {
|
||||
if (!Array.isArray(keyIds)) {
|
||||
keyIds = [keyIds]
|
||||
}
|
||||
|
||||
const { period: _period = 'week', model: _model } = options
|
||||
const stats = {
|
||||
totalRequests: 0,
|
||||
totalInputTokens: 0,
|
||||
totalOutputTokens: 0,
|
||||
totalCost: 0,
|
||||
dailyStats: [],
|
||||
modelStats: []
|
||||
}
|
||||
|
||||
// 汇总所有API Key的统计数据
|
||||
for (const keyId of keyIds) {
|
||||
const keyStats = await redis.getUsageStats(keyId)
|
||||
const costStats = await redis.getCostStats(keyId)
|
||||
if (keyStats && keyStats.total) {
|
||||
stats.totalRequests += keyStats.total.requests || 0
|
||||
stats.totalInputTokens += keyStats.total.inputTokens || 0
|
||||
stats.totalOutputTokens += keyStats.total.outputTokens || 0
|
||||
stats.totalCost += costStats?.total || 0
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 实现日期范围和模型统计
|
||||
// 这里可以根据需要添加更详细的统计逻辑
|
||||
|
||||
return stats
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get usage stats:', error)
|
||||
return {
|
||||
totalRequests: 0,
|
||||
totalInputTokens: 0,
|
||||
totalOutputTokens: 0,
|
||||
totalCost: 0,
|
||||
dailyStats: [],
|
||||
modelStats: []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 🧹 清理过期的API Keys
|
||||
async cleanupExpiredKeys() {
|
||||
try {
|
||||
|
||||
@@ -273,6 +273,11 @@ function handleStreamResponse(upstreamResponse, clientResponse, options = {}) {
|
||||
let eventCount = 0
|
||||
const maxEvents = 10000 // 最大事件数量限制
|
||||
|
||||
// 专门用于保存最后几个chunks以提取usage数据
|
||||
let finalChunksBuffer = ''
|
||||
const FINAL_CHUNKS_SIZE = 32 * 1024 // 32KB保留最终chunks
|
||||
const allParsedEvents = [] // 存储所有解析的事件用于最终usage提取
|
||||
|
||||
// 设置响应头
|
||||
clientResponse.setHeader('Content-Type', 'text/event-stream')
|
||||
clientResponse.setHeader('Cache-Control', 'no-cache')
|
||||
@@ -297,8 +302,8 @@ function handleStreamResponse(upstreamResponse, clientResponse, options = {}) {
|
||||
clientResponse.flushHeaders()
|
||||
}
|
||||
|
||||
// 解析 SSE 事件以捕获 usage 数据
|
||||
const parseSSEForUsage = (data) => {
|
||||
// 强化的SSE事件解析,保存所有事件用于最终处理
|
||||
const parseSSEForUsage = (data, isFromFinalBuffer = false) => {
|
||||
const lines = data.split('\n')
|
||||
|
||||
for (const line of lines) {
|
||||
@@ -310,34 +315,54 @@ function handleStreamResponse(upstreamResponse, clientResponse, options = {}) {
|
||||
}
|
||||
const eventData = JSON.parse(jsonStr)
|
||||
|
||||
// 保存所有成功解析的事件
|
||||
allParsedEvents.push(eventData)
|
||||
|
||||
// 获取模型信息
|
||||
if (eventData.model) {
|
||||
actualModel = eventData.model
|
||||
}
|
||||
|
||||
// 获取使用统计(Responses API: response.completed -> response.usage)
|
||||
if (eventData.type === 'response.completed' && eventData.response) {
|
||||
if (eventData.response.model) {
|
||||
actualModel = eventData.response.model
|
||||
}
|
||||
if (eventData.response.usage) {
|
||||
usageData = eventData.response.usage
|
||||
logger.debug('Captured Azure OpenAI nested usage (response.usage):', usageData)
|
||||
// 使用强化的usage提取函数
|
||||
const { usageData: extractedUsage, actualModel: extractedModel } =
|
||||
extractUsageDataRobust(
|
||||
eventData,
|
||||
`stream-event-${isFromFinalBuffer ? 'final' : 'normal'}`
|
||||
)
|
||||
|
||||
if (extractedUsage && !usageData) {
|
||||
usageData = extractedUsage
|
||||
if (extractedModel) {
|
||||
actualModel = extractedModel
|
||||
}
|
||||
logger.debug(`🎯 Stream usage captured via robust extraction`, {
|
||||
isFromFinalBuffer,
|
||||
usageData,
|
||||
actualModel
|
||||
})
|
||||
}
|
||||
|
||||
// 兼容 Chat Completions 风格(顶层 usage)
|
||||
if (!usageData && eventData.usage) {
|
||||
usageData = eventData.usage
|
||||
logger.debug('Captured Azure OpenAI usage (top-level):', usageData)
|
||||
}
|
||||
// 原有的简单提取作为备用
|
||||
if (!usageData) {
|
||||
// 获取使用统计(Responses API: response.completed -> response.usage)
|
||||
if (eventData.type === 'response.completed' && eventData.response) {
|
||||
if (eventData.response.model) {
|
||||
actualModel = eventData.response.model
|
||||
}
|
||||
if (eventData.response.usage) {
|
||||
usageData = eventData.response.usage
|
||||
logger.debug('🎯 Stream usage (backup method - response.usage):', usageData)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否是完成事件
|
||||
if (eventData.choices && eventData.choices[0] && eventData.choices[0].finish_reason) {
|
||||
// 这是最后一个 chunk
|
||||
// 兼容 Chat Completions 风格(顶层 usage)
|
||||
if (!usageData && eventData.usage) {
|
||||
usageData = eventData.usage
|
||||
logger.debug('🎯 Stream usage (backup method - top-level):', usageData)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// 忽略解析错误
|
||||
logger.debug('SSE parsing error (expected for incomplete chunks):', e.message)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -387,10 +412,19 @@ function handleStreamResponse(upstreamResponse, clientResponse, options = {}) {
|
||||
// 同时解析数据以捕获 usage 信息,带缓冲区大小限制
|
||||
buffer += chunkStr
|
||||
|
||||
// 防止缓冲区过大
|
||||
// 保留最后的chunks用于最终usage提取(不被truncate影响)
|
||||
finalChunksBuffer += chunkStr
|
||||
if (finalChunksBuffer.length > FINAL_CHUNKS_SIZE) {
|
||||
finalChunksBuffer = finalChunksBuffer.slice(-FINAL_CHUNKS_SIZE)
|
||||
}
|
||||
|
||||
// 防止主缓冲区过大 - 但保持最后部分用于usage解析
|
||||
if (buffer.length > MAX_BUFFER_SIZE) {
|
||||
logger.warn(`Stream ${streamId} buffer exceeded limit, truncating`)
|
||||
buffer = buffer.slice(-MAX_BUFFER_SIZE / 2) // 保留后一半
|
||||
logger.warn(
|
||||
`Stream ${streamId} buffer exceeded limit, truncating main buffer but preserving final chunks`
|
||||
)
|
||||
// 保留最后1/4而不是1/2,为usage数据留更多空间
|
||||
buffer = buffer.slice(-MAX_BUFFER_SIZE / 4)
|
||||
}
|
||||
|
||||
// 处理完整的 SSE 事件
|
||||
@@ -426,9 +460,91 @@ function handleStreamResponse(upstreamResponse, clientResponse, options = {}) {
|
||||
hasEnded = true
|
||||
|
||||
try {
|
||||
// 处理剩余的 buffer
|
||||
if (buffer.trim() && buffer.length <= MAX_EVENT_SIZE) {
|
||||
parseSSEForUsage(buffer)
|
||||
logger.debug(`🔚 Stream ended, performing comprehensive usage extraction for ${streamId}`, {
|
||||
mainBufferSize: buffer.length,
|
||||
finalChunksBufferSize: finalChunksBuffer.length,
|
||||
parsedEventsCount: allParsedEvents.length,
|
||||
hasUsageData: !!usageData
|
||||
})
|
||||
|
||||
// 多层次的最终usage提取策略
|
||||
if (!usageData) {
|
||||
logger.debug('🔍 No usage found during stream, trying final extraction methods...')
|
||||
|
||||
// 方法1: 解析剩余的主buffer
|
||||
if (buffer.trim() && buffer.length <= MAX_EVENT_SIZE) {
|
||||
parseSSEForUsage(buffer, false)
|
||||
}
|
||||
|
||||
// 方法2: 解析保留的final chunks buffer
|
||||
if (!usageData && finalChunksBuffer.trim()) {
|
||||
logger.debug('🔍 Trying final chunks buffer for usage extraction...')
|
||||
parseSSEForUsage(finalChunksBuffer, true)
|
||||
}
|
||||
|
||||
// 方法3: 从所有解析的事件中重新搜索usage
|
||||
if (!usageData && allParsedEvents.length > 0) {
|
||||
logger.debug('🔍 Searching through all parsed events for usage...')
|
||||
|
||||
// 倒序查找,因为usage通常在最后
|
||||
for (let i = allParsedEvents.length - 1; i >= 0; i--) {
|
||||
const { usageData: foundUsage, actualModel: foundModel } = extractUsageDataRobust(
|
||||
allParsedEvents[i],
|
||||
`final-event-scan-${i}`
|
||||
)
|
||||
if (foundUsage) {
|
||||
usageData = foundUsage
|
||||
if (foundModel) {
|
||||
actualModel = foundModel
|
||||
}
|
||||
logger.debug(`🎯 Usage found in event ${i} during final scan!`)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 方法4: 尝试合并所有事件并搜索
|
||||
if (!usageData && allParsedEvents.length > 0) {
|
||||
logger.debug('🔍 Trying combined events analysis...')
|
||||
const combinedData = {
|
||||
events: allParsedEvents,
|
||||
lastEvent: allParsedEvents[allParsedEvents.length - 1],
|
||||
eventCount: allParsedEvents.length
|
||||
}
|
||||
|
||||
const { usageData: combinedUsage } = extractUsageDataRobust(
|
||||
combinedData,
|
||||
'combined-events'
|
||||
)
|
||||
if (combinedUsage) {
|
||||
usageData = combinedUsage
|
||||
logger.debug('🎯 Usage found via combined events analysis!')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 最终usage状态报告
|
||||
if (usageData) {
|
||||
logger.debug('✅ Final stream usage extraction SUCCESS', {
|
||||
streamId,
|
||||
usageData,
|
||||
actualModel,
|
||||
totalEvents: allParsedEvents.length,
|
||||
finalBufferSize: finalChunksBuffer.length
|
||||
})
|
||||
} else {
|
||||
logger.warn('❌ Final stream usage extraction FAILED', {
|
||||
streamId,
|
||||
totalEvents: allParsedEvents.length,
|
||||
finalBufferSize: finalChunksBuffer.length,
|
||||
mainBufferSize: buffer.length,
|
||||
lastFewEvents: allParsedEvents.slice(-3).map((e) => ({
|
||||
type: e.type,
|
||||
hasUsage: !!e.usage,
|
||||
hasResponse: !!e.response,
|
||||
keys: Object.keys(e)
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
if (onEnd) {
|
||||
@@ -484,6 +600,120 @@ function handleStreamResponse(upstreamResponse, clientResponse, options = {}) {
|
||||
})
|
||||
}
|
||||
|
||||
// 强化的用量数据提取函数
|
||||
function extractUsageDataRobust(responseData, context = 'unknown') {
|
||||
logger.debug(`🔍 Attempting usage extraction for ${context}`, {
|
||||
responseDataKeys: Object.keys(responseData || {}),
|
||||
responseDataType: typeof responseData,
|
||||
hasUsage: !!responseData?.usage,
|
||||
hasResponse: !!responseData?.response
|
||||
})
|
||||
|
||||
let usageData = null
|
||||
let actualModel = null
|
||||
|
||||
try {
|
||||
// 策略 1: 顶层 usage (标准 Chat Completions)
|
||||
if (responseData?.usage) {
|
||||
usageData = responseData.usage
|
||||
actualModel = responseData.model
|
||||
logger.debug('✅ Usage extracted via Strategy 1 (top-level)', { usageData, actualModel })
|
||||
}
|
||||
|
||||
// 策略 2: response.usage (Responses API)
|
||||
else if (responseData?.response?.usage) {
|
||||
usageData = responseData.response.usage
|
||||
actualModel = responseData.response.model || responseData.model
|
||||
logger.debug('✅ Usage extracted via Strategy 2 (response.usage)', { usageData, actualModel })
|
||||
}
|
||||
|
||||
// 策略 3: 嵌套搜索 - 深度查找 usage 字段
|
||||
else {
|
||||
const findUsageRecursive = (obj, path = '') => {
|
||||
if (!obj || typeof obj !== 'object') {
|
||||
return null
|
||||
}
|
||||
|
||||
for (const [key, value] of Object.entries(obj)) {
|
||||
const currentPath = path ? `${path}.${key}` : key
|
||||
|
||||
if (key === 'usage' && value && typeof value === 'object') {
|
||||
logger.debug(`✅ Usage found at path: ${currentPath}`, value)
|
||||
return { usage: value, path: currentPath }
|
||||
}
|
||||
|
||||
if (typeof value === 'object' && value !== null) {
|
||||
const nested = findUsageRecursive(value, currentPath)
|
||||
if (nested) {
|
||||
return nested
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
const found = findUsageRecursive(responseData)
|
||||
if (found) {
|
||||
usageData = found.usage
|
||||
// Try to find model in the same parent object
|
||||
const pathParts = found.path.split('.')
|
||||
pathParts.pop() // remove 'usage'
|
||||
let modelParent = responseData
|
||||
for (const part of pathParts) {
|
||||
modelParent = modelParent?.[part]
|
||||
}
|
||||
actualModel = modelParent?.model || responseData?.model
|
||||
logger.debug('✅ Usage extracted via Strategy 3 (recursive)', {
|
||||
usageData,
|
||||
actualModel,
|
||||
foundPath: found.path
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 策略 4: 特殊响应格式处理
|
||||
if (!usageData) {
|
||||
// 检查是否有 choices 数组,usage 可能在最后一个 choice 中
|
||||
if (responseData?.choices?.length > 0) {
|
||||
const lastChoice = responseData.choices[responseData.choices.length - 1]
|
||||
if (lastChoice?.usage) {
|
||||
usageData = lastChoice.usage
|
||||
actualModel = responseData.model || lastChoice.model
|
||||
logger.debug('✅ Usage extracted via Strategy 4 (choices)', { usageData, actualModel })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 最终验证和记录
|
||||
if (usageData) {
|
||||
logger.debug('🎯 Final usage extraction result', {
|
||||
context,
|
||||
usageData,
|
||||
actualModel,
|
||||
inputTokens: usageData.prompt_tokens || usageData.input_tokens || 0,
|
||||
outputTokens: usageData.completion_tokens || usageData.output_tokens || 0,
|
||||
totalTokens: usageData.total_tokens || 0
|
||||
})
|
||||
} else {
|
||||
logger.warn('❌ Failed to extract usage data', {
|
||||
context,
|
||||
responseDataStructure: `${JSON.stringify(responseData, null, 2).substring(0, 1000)}...`,
|
||||
availableKeys: Object.keys(responseData || {}),
|
||||
responseSize: JSON.stringify(responseData || {}).length
|
||||
})
|
||||
}
|
||||
} catch (extractionError) {
|
||||
logger.error('🚨 Error during usage extraction', {
|
||||
context,
|
||||
error: extractionError.message,
|
||||
stack: extractionError.stack,
|
||||
responseDataType: typeof responseData
|
||||
})
|
||||
}
|
||||
|
||||
return { usageData, actualModel }
|
||||
}
|
||||
|
||||
// 处理非流式响应
|
||||
function handleNonStreamResponse(upstreamResponse, clientResponse) {
|
||||
try {
|
||||
@@ -510,9 +740,8 @@ function handleNonStreamResponse(upstreamResponse, clientResponse) {
|
||||
const responseData = upstreamResponse.data
|
||||
clientResponse.json(responseData)
|
||||
|
||||
// 提取 usage 数据
|
||||
const usageData = responseData.usage
|
||||
const actualModel = responseData.model
|
||||
// 使用强化的用量提取
|
||||
const { usageData, actualModel } = extractUsageDataRobust(responseData, 'non-stream')
|
||||
|
||||
return { usageData, actualModel, responseData }
|
||||
} catch (error) {
|
||||
|
||||
@@ -138,11 +138,19 @@ function createOAuth2Client(redirectUri = null, proxyConfig = null) {
|
||||
return new OAuth2Client(clientOptions)
|
||||
}
|
||||
|
||||
// 生成授权 URL (支持 PKCE)
|
||||
async function generateAuthUrl(state = null, redirectUri = null) {
|
||||
// 生成授权 URL (支持 PKCE 和代理)
|
||||
async function generateAuthUrl(state = null, redirectUri = null, proxyConfig = null) {
|
||||
// 使用新的 redirect URI
|
||||
const finalRedirectUri = redirectUri || 'https://codeassist.google.com/authcode'
|
||||
const oAuth2Client = createOAuth2Client(finalRedirectUri)
|
||||
const oAuth2Client = createOAuth2Client(finalRedirectUri, proxyConfig)
|
||||
|
||||
if (proxyConfig) {
|
||||
logger.info(
|
||||
`🌐 Using proxy for Gemini auth URL generation: ${ProxyHelper.getProxyDescription(proxyConfig)}`
|
||||
)
|
||||
} else {
|
||||
logger.debug('🌐 No proxy configured for Gemini auth URL generation')
|
||||
}
|
||||
|
||||
// 生成 PKCE code verifier
|
||||
const codeVerifier = await oAuth2Client.generateCodeVerifierAsync()
|
||||
@@ -965,12 +973,10 @@ async function getAccountRateLimitInfo(accountId) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取配置的OAuth客户端 - 参考GeminiCliSimulator的getOauthClient方法
|
||||
async function getOauthClient(accessToken, refreshToken) {
|
||||
const client = new OAuth2Client({
|
||||
clientId: OAUTH_CLIENT_ID,
|
||||
clientSecret: OAUTH_CLIENT_SECRET
|
||||
})
|
||||
// 获取配置的OAuth客户端 - 参考GeminiCliSimulator的getOauthClient方法(支持代理)
|
||||
async function getOauthClient(accessToken, refreshToken, proxyConfig = null) {
|
||||
const client = createOAuth2Client(null, proxyConfig)
|
||||
|
||||
const creds = {
|
||||
access_token: accessToken,
|
||||
refresh_token: refreshToken,
|
||||
@@ -980,6 +986,14 @@ async function getOauthClient(accessToken, refreshToken) {
|
||||
expiry_date: 1754269905646
|
||||
}
|
||||
|
||||
if (proxyConfig) {
|
||||
logger.info(
|
||||
`🌐 Using proxy for Gemini OAuth client: ${ProxyHelper.getProxyDescription(proxyConfig)}`
|
||||
)
|
||||
} else {
|
||||
logger.debug('🌐 No proxy configured for Gemini OAuth client')
|
||||
}
|
||||
|
||||
// 设置凭据
|
||||
client.setCredentials(creds)
|
||||
|
||||
@@ -996,8 +1010,8 @@ async function getOauthClient(accessToken, refreshToken) {
|
||||
return client
|
||||
}
|
||||
|
||||
// 调用 Google Code Assist API 的 loadCodeAssist 方法
|
||||
async function loadCodeAssist(client, projectId = null) {
|
||||
// 调用 Google Code Assist API 的 loadCodeAssist 方法(支持代理)
|
||||
async function loadCodeAssist(client, projectId = null, proxyConfig = null) {
|
||||
const axios = require('axios')
|
||||
const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'
|
||||
const CODE_ASSIST_API_VERSION = 'v1internal'
|
||||
@@ -1017,7 +1031,7 @@ async function loadCodeAssist(client, projectId = null) {
|
||||
metadata: clientMetadata
|
||||
}
|
||||
|
||||
const response = await axios({
|
||||
const axiosConfig = {
|
||||
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:loadCodeAssist`,
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -1026,7 +1040,20 @@ async function loadCodeAssist(client, projectId = null) {
|
||||
},
|
||||
data: request,
|
||||
timeout: 30000
|
||||
})
|
||||
}
|
||||
|
||||
// 添加代理配置
|
||||
const proxyAgent = ProxyHelper.createProxyAgent(proxyConfig)
|
||||
if (proxyAgent) {
|
||||
axiosConfig.httpsAgent = proxyAgent
|
||||
logger.info(
|
||||
`🌐 Using proxy for Gemini loadCodeAssist: ${ProxyHelper.getProxyDescription(proxyConfig)}`
|
||||
)
|
||||
} else {
|
||||
logger.debug('🌐 No proxy configured for Gemini loadCodeAssist')
|
||||
}
|
||||
|
||||
const response = await axios(axiosConfig)
|
||||
|
||||
logger.info('📋 loadCodeAssist API调用成功')
|
||||
return response.data
|
||||
@@ -1059,8 +1086,8 @@ function getOnboardTier(loadRes) {
|
||||
}
|
||||
}
|
||||
|
||||
// 调用 Google Code Assist API 的 onboardUser 方法(包含轮询逻辑)
|
||||
async function onboardUser(client, tierId, projectId, clientMetadata) {
|
||||
// 调用 Google Code Assist API 的 onboardUser 方法(包含轮询逻辑,支持代理)
|
||||
async function onboardUser(client, tierId, projectId, clientMetadata, proxyConfig = null) {
|
||||
const axios = require('axios')
|
||||
const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'
|
||||
const CODE_ASSIST_API_VERSION = 'v1internal'
|
||||
@@ -1073,15 +1100,8 @@ async function onboardUser(client, tierId, projectId, clientMetadata) {
|
||||
metadata: clientMetadata
|
||||
}
|
||||
|
||||
logger.info('📋 开始onboardUser API调用', {
|
||||
tierId,
|
||||
projectId,
|
||||
hasProjectId: !!projectId,
|
||||
isFreeTier: tierId === 'free-tier' || tierId === 'FREE'
|
||||
})
|
||||
|
||||
// 轮询onboardUser直到长运行操作完成
|
||||
let lroRes = await axios({
|
||||
// 创建基础axios配置
|
||||
const baseAxiosConfig = {
|
||||
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:onboardUser`,
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -1090,8 +1110,29 @@ async function onboardUser(client, tierId, projectId, clientMetadata) {
|
||||
},
|
||||
data: onboardReq,
|
||||
timeout: 30000
|
||||
}
|
||||
|
||||
// 添加代理配置
|
||||
const proxyAgent = ProxyHelper.createProxyAgent(proxyConfig)
|
||||
if (proxyAgent) {
|
||||
baseAxiosConfig.httpsAgent = proxyAgent
|
||||
logger.info(
|
||||
`🌐 Using proxy for Gemini onboardUser: ${ProxyHelper.getProxyDescription(proxyConfig)}`
|
||||
)
|
||||
} else {
|
||||
logger.debug('🌐 No proxy configured for Gemini onboardUser')
|
||||
}
|
||||
|
||||
logger.info('📋 开始onboardUser API调用', {
|
||||
tierId,
|
||||
projectId,
|
||||
hasProjectId: !!projectId,
|
||||
isFreeTier: tierId === 'free-tier' || tierId === 'FREE'
|
||||
})
|
||||
|
||||
// 轮询onboardUser直到长运行操作完成
|
||||
let lroRes = await axios(baseAxiosConfig)
|
||||
|
||||
let attempts = 0
|
||||
const maxAttempts = 12 // 最多等待1分钟(5秒 * 12次)
|
||||
|
||||
@@ -1099,17 +1140,7 @@ async function onboardUser(client, tierId, projectId, clientMetadata) {
|
||||
logger.info(`⏳ 等待onboardUser完成... (${attempts + 1}/${maxAttempts})`)
|
||||
await new Promise((resolve) => setTimeout(resolve, 5000))
|
||||
|
||||
lroRes = await axios({
|
||||
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:onboardUser`,
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
data: onboardReq,
|
||||
timeout: 30000
|
||||
})
|
||||
|
||||
lroRes = await axios(baseAxiosConfig)
|
||||
attempts++
|
||||
}
|
||||
|
||||
@@ -1121,8 +1152,13 @@ async function onboardUser(client, tierId, projectId, clientMetadata) {
|
||||
return lroRes.data
|
||||
}
|
||||
|
||||
// 完整的用户设置流程 - 参考setup.ts的逻辑
|
||||
async function setupUser(client, initialProjectId = null, clientMetadata = null) {
|
||||
// 完整的用户设置流程 - 参考setup.ts的逻辑(支持代理)
|
||||
async function setupUser(
|
||||
client,
|
||||
initialProjectId = null,
|
||||
clientMetadata = null,
|
||||
proxyConfig = null
|
||||
) {
|
||||
logger.info('🚀 setupUser 开始', { initialProjectId, hasClientMetadata: !!clientMetadata })
|
||||
|
||||
let projectId = initialProjectId || process.env.GOOGLE_CLOUD_PROJECT || null
|
||||
@@ -1141,7 +1177,7 @@ async function setupUser(client, initialProjectId = null, clientMetadata = null)
|
||||
|
||||
// 调用loadCodeAssist
|
||||
logger.info('📞 调用 loadCodeAssist...')
|
||||
const loadRes = await loadCodeAssist(client, projectId)
|
||||
const loadRes = await loadCodeAssist(client, projectId, proxyConfig)
|
||||
logger.info('✅ loadCodeAssist 完成', {
|
||||
hasCloudaicompanionProject: !!loadRes.cloudaicompanionProject
|
||||
})
|
||||
@@ -1164,7 +1200,7 @@ async function setupUser(client, initialProjectId = null, clientMetadata = null)
|
||||
|
||||
// 调用onboardUser
|
||||
logger.info('📞 调用 onboardUser...', { tierId: tier.id, projectId })
|
||||
const lroRes = await onboardUser(client, tier.id, projectId, clientMetadata)
|
||||
const lroRes = await onboardUser(client, tier.id, projectId, clientMetadata, proxyConfig)
|
||||
logger.info('✅ onboardUser 完成', { hasDone: !!lroRes.done, hasResponse: !!lroRes.response })
|
||||
|
||||
const result = {
|
||||
@@ -1178,8 +1214,8 @@ async function setupUser(client, initialProjectId = null, clientMetadata = null)
|
||||
return result
|
||||
}
|
||||
|
||||
// 调用 Code Assist API 计算 token 数量
|
||||
async function countTokens(client, contents, model = 'gemini-2.0-flash-exp') {
|
||||
// 调用 Code Assist API 计算 token 数量(支持代理)
|
||||
async function countTokens(client, contents, model = 'gemini-2.0-flash-exp', proxyConfig = null) {
|
||||
const axios = require('axios')
|
||||
const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'
|
||||
const CODE_ASSIST_API_VERSION = 'v1internal'
|
||||
@@ -1196,7 +1232,7 @@ async function countTokens(client, contents, model = 'gemini-2.0-flash-exp') {
|
||||
|
||||
logger.info('📊 countTokens API调用开始', { model, contentsLength: contents.length })
|
||||
|
||||
const response = await axios({
|
||||
const axiosConfig = {
|
||||
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:countTokens`,
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -1205,7 +1241,20 @@ async function countTokens(client, contents, model = 'gemini-2.0-flash-exp') {
|
||||
},
|
||||
data: request,
|
||||
timeout: 30000
|
||||
})
|
||||
}
|
||||
|
||||
// 添加代理配置
|
||||
const proxyAgent = ProxyHelper.createProxyAgent(proxyConfig)
|
||||
if (proxyAgent) {
|
||||
axiosConfig.httpsAgent = proxyAgent
|
||||
logger.info(
|
||||
`🌐 Using proxy for Gemini countTokens: ${ProxyHelper.getProxyDescription(proxyConfig)}`
|
||||
)
|
||||
} else {
|
||||
logger.debug('🌐 No proxy configured for Gemini countTokens')
|
||||
}
|
||||
|
||||
const response = await axios(axiosConfig)
|
||||
|
||||
logger.info('✅ countTokens API调用成功', { totalTokens: response.data.totalTokens })
|
||||
return response.data
|
||||
|
||||
591
src/services/ldapService.js
Normal file
591
src/services/ldapService.js
Normal file
@@ -0,0 +1,591 @@
|
||||
const ldap = require('ldapjs')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
const userService = require('./userService')
|
||||
|
||||
class LdapService {
|
||||
constructor() {
|
||||
this.config = config.ldap
|
||||
this.client = null
|
||||
|
||||
// 验证配置
|
||||
if (this.config.enabled) {
|
||||
this.validateConfiguration()
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 验证LDAP配置
|
||||
validateConfiguration() {
|
||||
const errors = []
|
||||
|
||||
if (!this.config.server) {
|
||||
errors.push('LDAP server configuration is missing')
|
||||
} else {
|
||||
if (!this.config.server.url || typeof this.config.server.url !== 'string') {
|
||||
errors.push('LDAP server URL is not configured or invalid')
|
||||
}
|
||||
|
||||
if (!this.config.server.bindDN || typeof this.config.server.bindDN !== 'string') {
|
||||
errors.push('LDAP bind DN is not configured or invalid')
|
||||
}
|
||||
|
||||
if (
|
||||
!this.config.server.bindCredentials ||
|
||||
typeof this.config.server.bindCredentials !== 'string'
|
||||
) {
|
||||
errors.push('LDAP bind credentials are not configured or invalid')
|
||||
}
|
||||
|
||||
if (!this.config.server.searchBase || typeof this.config.server.searchBase !== 'string') {
|
||||
errors.push('LDAP search base is not configured or invalid')
|
||||
}
|
||||
|
||||
if (!this.config.server.searchFilter || typeof this.config.server.searchFilter !== 'string') {
|
||||
errors.push('LDAP search filter is not configured or invalid')
|
||||
}
|
||||
}
|
||||
|
||||
if (errors.length > 0) {
|
||||
logger.error('❌ LDAP configuration validation failed:', errors)
|
||||
// Don't throw error during initialization, just log warnings
|
||||
logger.warn('⚠️ LDAP authentication may not work properly due to configuration errors')
|
||||
} else {
|
||||
logger.info('✅ LDAP configuration validation passed')
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 提取LDAP条目的DN
|
||||
extractDN(ldapEntry) {
|
||||
if (!ldapEntry) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Try different ways to get the DN
|
||||
let dn = null
|
||||
|
||||
// Method 1: Direct dn property
|
||||
if (ldapEntry.dn) {
|
||||
;({ dn } = ldapEntry)
|
||||
}
|
||||
// Method 2: objectName property (common in some LDAP implementations)
|
||||
else if (ldapEntry.objectName) {
|
||||
dn = ldapEntry.objectName
|
||||
}
|
||||
// Method 3: distinguishedName property
|
||||
else if (ldapEntry.distinguishedName) {
|
||||
dn = ldapEntry.distinguishedName
|
||||
}
|
||||
// Method 4: Check if the entry itself is a DN string
|
||||
else if (typeof ldapEntry === 'string' && ldapEntry.includes('=')) {
|
||||
dn = ldapEntry
|
||||
}
|
||||
|
||||
// Convert DN to string if it's an object
|
||||
if (dn && typeof dn === 'object') {
|
||||
if (dn.toString && typeof dn.toString === 'function') {
|
||||
dn = dn.toString()
|
||||
} else if (dn.dn && typeof dn.dn === 'string') {
|
||||
;({ dn } = dn)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the DN format
|
||||
if (typeof dn === 'string' && dn.trim() !== '' && dn.includes('=')) {
|
||||
return dn.trim()
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
// 🔗 创建LDAP客户端连接
|
||||
createClient() {
|
||||
try {
|
||||
const clientOptions = {
|
||||
url: this.config.server.url,
|
||||
timeout: this.config.server.timeout,
|
||||
connectTimeout: this.config.server.connectTimeout,
|
||||
reconnect: true
|
||||
}
|
||||
|
||||
// 如果使用 LDAPS (SSL/TLS),添加 TLS 选项
|
||||
if (this.config.server.url.toLowerCase().startsWith('ldaps://')) {
|
||||
const tlsOptions = {}
|
||||
|
||||
// 证书验证设置
|
||||
if (this.config.server.tls) {
|
||||
if (typeof this.config.server.tls.rejectUnauthorized === 'boolean') {
|
||||
tlsOptions.rejectUnauthorized = this.config.server.tls.rejectUnauthorized
|
||||
}
|
||||
|
||||
// CA 证书
|
||||
if (this.config.server.tls.ca) {
|
||||
tlsOptions.ca = this.config.server.tls.ca
|
||||
}
|
||||
|
||||
// 客户端证书和私钥 (双向认证)
|
||||
if (this.config.server.tls.cert) {
|
||||
tlsOptions.cert = this.config.server.tls.cert
|
||||
}
|
||||
|
||||
if (this.config.server.tls.key) {
|
||||
tlsOptions.key = this.config.server.tls.key
|
||||
}
|
||||
|
||||
// 服务器名称 (SNI)
|
||||
if (this.config.server.tls.servername) {
|
||||
tlsOptions.servername = this.config.server.tls.servername
|
||||
}
|
||||
}
|
||||
|
||||
clientOptions.tlsOptions = tlsOptions
|
||||
|
||||
logger.debug('🔒 Creating LDAPS client with TLS options:', {
|
||||
url: this.config.server.url,
|
||||
rejectUnauthorized: tlsOptions.rejectUnauthorized,
|
||||
hasCA: !!tlsOptions.ca,
|
||||
hasCert: !!tlsOptions.cert,
|
||||
hasKey: !!tlsOptions.key,
|
||||
servername: tlsOptions.servername
|
||||
})
|
||||
}
|
||||
|
||||
const client = ldap.createClient(clientOptions)
|
||||
|
||||
// 设置错误处理
|
||||
client.on('error', (err) => {
|
||||
if (err.code === 'CERT_HAS_EXPIRED' || err.code === 'UNABLE_TO_VERIFY_LEAF_SIGNATURE') {
|
||||
logger.error('🔒 LDAP TLS certificate error:', {
|
||||
code: err.code,
|
||||
message: err.message,
|
||||
hint: 'Consider setting LDAP_TLS_REJECT_UNAUTHORIZED=false for self-signed certificates'
|
||||
})
|
||||
} else {
|
||||
logger.error('🔌 LDAP client error:', err)
|
||||
}
|
||||
})
|
||||
|
||||
client.on('connect', () => {
|
||||
if (this.config.server.url.toLowerCase().startsWith('ldaps://')) {
|
||||
logger.info('🔒 LDAPS client connected successfully')
|
||||
} else {
|
||||
logger.info('🔗 LDAP client connected successfully')
|
||||
}
|
||||
})
|
||||
|
||||
client.on('connectTimeout', () => {
|
||||
logger.warn('⏱️ LDAP connection timeout')
|
||||
})
|
||||
|
||||
return client
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to create LDAP client:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔒 绑定LDAP连接(管理员认证)
|
||||
async bindClient(client) {
|
||||
return new Promise((resolve, reject) => {
|
||||
// 验证绑定凭据
|
||||
const { bindDN } = this.config.server
|
||||
const { bindCredentials } = this.config.server
|
||||
|
||||
if (!bindDN || typeof bindDN !== 'string') {
|
||||
const error = new Error('LDAP bind DN is not configured or invalid')
|
||||
logger.error('❌ LDAP configuration error:', error.message)
|
||||
reject(error)
|
||||
return
|
||||
}
|
||||
|
||||
if (!bindCredentials || typeof bindCredentials !== 'string') {
|
||||
const error = new Error('LDAP bind credentials are not configured or invalid')
|
||||
logger.error('❌ LDAP configuration error:', error.message)
|
||||
reject(error)
|
||||
return
|
||||
}
|
||||
|
||||
client.bind(bindDN, bindCredentials, (err) => {
|
||||
if (err) {
|
||||
logger.error('❌ LDAP bind failed:', err)
|
||||
reject(err)
|
||||
} else {
|
||||
logger.debug('🔑 LDAP bind successful')
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 🔍 搜索用户
|
||||
async searchUser(client, username) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const searchFilter = this.config.server.searchFilter.replace('{{username}}', username)
|
||||
const searchOptions = {
|
||||
scope: 'sub',
|
||||
filter: searchFilter,
|
||||
attributes: this.config.server.searchAttributes
|
||||
}
|
||||
|
||||
logger.debug(`🔍 Searching for user: ${username} with filter: ${searchFilter}`)
|
||||
|
||||
const entries = []
|
||||
|
||||
client.search(this.config.server.searchBase, searchOptions, (err, res) => {
|
||||
if (err) {
|
||||
logger.error('❌ LDAP search error:', err)
|
||||
reject(err)
|
||||
return
|
||||
}
|
||||
|
||||
res.on('searchEntry', (entry) => {
|
||||
logger.debug('🔍 LDAP search entry received:', {
|
||||
dn: entry.dn,
|
||||
objectName: entry.objectName,
|
||||
type: typeof entry.dn,
|
||||
entryType: typeof entry,
|
||||
hasAttributes: !!entry.attributes,
|
||||
attributeCount: entry.attributes ? entry.attributes.length : 0
|
||||
})
|
||||
entries.push(entry)
|
||||
})
|
||||
|
||||
res.on('searchReference', (referral) => {
|
||||
logger.debug('🔗 LDAP search referral:', referral.uris)
|
||||
})
|
||||
|
||||
res.on('error', (error) => {
|
||||
logger.error('❌ LDAP search result error:', error)
|
||||
reject(error)
|
||||
})
|
||||
|
||||
res.on('end', (result) => {
|
||||
logger.debug(
|
||||
`✅ LDAP search completed. Status: ${result.status}, Found ${entries.length} entries`
|
||||
)
|
||||
|
||||
if (entries.length === 0) {
|
||||
resolve(null)
|
||||
} else {
|
||||
// Log the structure of the first entry for debugging
|
||||
if (entries[0]) {
|
||||
logger.debug('🔍 Full LDAP entry structure:', {
|
||||
entryType: typeof entries[0],
|
||||
entryConstructor: entries[0].constructor?.name,
|
||||
entryKeys: Object.keys(entries[0]),
|
||||
entryStringified: JSON.stringify(entries[0], null, 2).substring(0, 500)
|
||||
})
|
||||
}
|
||||
|
||||
if (entries.length === 1) {
|
||||
resolve(entries[0])
|
||||
} else {
|
||||
logger.warn(`⚠️ Multiple LDAP entries found for username: ${username}`)
|
||||
resolve(entries[0]) // 使用第一个结果
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 🔐 验证用户密码
|
||||
async authenticateUser(userDN, password) {
|
||||
return new Promise((resolve, reject) => {
|
||||
// 验证输入参数
|
||||
if (!userDN || typeof userDN !== 'string') {
|
||||
const error = new Error('User DN is not provided or invalid')
|
||||
logger.error('❌ LDAP authentication error:', error.message)
|
||||
reject(error)
|
||||
return
|
||||
}
|
||||
|
||||
if (!password || typeof password !== 'string') {
|
||||
logger.debug(`🚫 Invalid or empty password for DN: ${userDN}`)
|
||||
resolve(false)
|
||||
return
|
||||
}
|
||||
|
||||
const authClient = this.createClient()
|
||||
|
||||
authClient.bind(userDN, password, (err) => {
|
||||
authClient.unbind() // 立即关闭认证客户端
|
||||
|
||||
if (err) {
|
||||
if (err.name === 'InvalidCredentialsError') {
|
||||
logger.debug(`🚫 Invalid credentials for DN: ${userDN}`)
|
||||
resolve(false)
|
||||
} else {
|
||||
logger.error('❌ LDAP authentication error:', err)
|
||||
reject(err)
|
||||
}
|
||||
} else {
|
||||
logger.debug(`✅ Authentication successful for DN: ${userDN}`)
|
||||
resolve(true)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 📝 提取用户信息
|
||||
extractUserInfo(ldapEntry, username) {
|
||||
try {
|
||||
const attributes = ldapEntry.attributes || []
|
||||
const userInfo = { username }
|
||||
|
||||
// 创建属性映射
|
||||
const attrMap = {}
|
||||
attributes.forEach((attr) => {
|
||||
const name = attr.type || attr.name
|
||||
const values = Array.isArray(attr.values) ? attr.values : [attr.values]
|
||||
attrMap[name] = values.length === 1 ? values[0] : values
|
||||
})
|
||||
|
||||
// 根据配置映射用户属性
|
||||
const mapping = this.config.userMapping
|
||||
|
||||
userInfo.displayName = attrMap[mapping.displayName] || username
|
||||
userInfo.email = attrMap[mapping.email] || ''
|
||||
userInfo.firstName = attrMap[mapping.firstName] || ''
|
||||
userInfo.lastName = attrMap[mapping.lastName] || ''
|
||||
|
||||
// 如果没有displayName,尝试组合firstName和lastName
|
||||
if (!userInfo.displayName || userInfo.displayName === username) {
|
||||
if (userInfo.firstName || userInfo.lastName) {
|
||||
userInfo.displayName = `${userInfo.firstName || ''} ${userInfo.lastName || ''}`.trim()
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug('📋 Extracted user info:', {
|
||||
username: userInfo.username,
|
||||
displayName: userInfo.displayName,
|
||||
email: userInfo.email
|
||||
})
|
||||
|
||||
return userInfo
|
||||
} catch (error) {
|
||||
logger.error('❌ Error extracting user info:', error)
|
||||
return { username }
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 验证和清理用户名
|
||||
validateAndSanitizeUsername(username) {
|
||||
if (!username || typeof username !== 'string' || username.trim() === '') {
|
||||
throw new Error('Username is required and must be a non-empty string')
|
||||
}
|
||||
|
||||
const trimmedUsername = username.trim()
|
||||
|
||||
// 用户名只能包含字母、数字、下划线和连字符
|
||||
const usernameRegex = /^[a-zA-Z0-9_-]+$/
|
||||
if (!usernameRegex.test(trimmedUsername)) {
|
||||
throw new Error('Username can only contain letters, numbers, underscores, and hyphens')
|
||||
}
|
||||
|
||||
// 长度限制 (防止过长的输入)
|
||||
if (trimmedUsername.length > 64) {
|
||||
throw new Error('Username cannot exceed 64 characters')
|
||||
}
|
||||
|
||||
// 不能以连字符开头或结尾
|
||||
if (trimmedUsername.startsWith('-') || trimmedUsername.endsWith('-')) {
|
||||
throw new Error('Username cannot start or end with a hyphen')
|
||||
}
|
||||
|
||||
return trimmedUsername
|
||||
}
|
||||
|
||||
// 🔐 主要的登录验证方法
|
||||
async authenticateUserCredentials(username, password) {
|
||||
if (!this.config.enabled) {
|
||||
throw new Error('LDAP authentication is not enabled')
|
||||
}
|
||||
|
||||
// 验证和清理用户名 (防止LDAP注入)
|
||||
const sanitizedUsername = this.validateAndSanitizeUsername(username)
|
||||
|
||||
if (!password || typeof password !== 'string' || password.trim() === '') {
|
||||
throw new Error('Password is required and must be a non-empty string')
|
||||
}
|
||||
|
||||
// 验证LDAP服务器配置
|
||||
if (!this.config.server || !this.config.server.url) {
|
||||
throw new Error('LDAP server URL is not configured')
|
||||
}
|
||||
|
||||
if (!this.config.server.bindDN || typeof this.config.server.bindDN !== 'string') {
|
||||
throw new Error('LDAP bind DN is not configured')
|
||||
}
|
||||
|
||||
if (
|
||||
!this.config.server.bindCredentials ||
|
||||
typeof this.config.server.bindCredentials !== 'string'
|
||||
) {
|
||||
throw new Error('LDAP bind credentials are not configured')
|
||||
}
|
||||
|
||||
if (!this.config.server.searchBase || typeof this.config.server.searchBase !== 'string') {
|
||||
throw new Error('LDAP search base is not configured')
|
||||
}
|
||||
|
||||
const client = this.createClient()
|
||||
|
||||
try {
|
||||
// 1. 使用管理员凭据绑定
|
||||
await this.bindClient(client)
|
||||
|
||||
// 2. 搜索用户 (使用已验证的用户名)
|
||||
const ldapEntry = await this.searchUser(client, sanitizedUsername)
|
||||
if (!ldapEntry) {
|
||||
logger.info(`🚫 User not found in LDAP: ${sanitizedUsername}`)
|
||||
return { success: false, message: 'Invalid username or password' }
|
||||
}
|
||||
|
||||
// 3. 获取用户DN
|
||||
logger.debug('🔍 LDAP entry details for DN extraction:', {
|
||||
hasEntry: !!ldapEntry,
|
||||
entryType: typeof ldapEntry,
|
||||
entryKeys: Object.keys(ldapEntry || {}),
|
||||
dn: ldapEntry.dn,
|
||||
objectName: ldapEntry.objectName,
|
||||
dnType: typeof ldapEntry.dn,
|
||||
objectNameType: typeof ldapEntry.objectName
|
||||
})
|
||||
|
||||
// Use the helper method to extract DN
|
||||
const userDN = this.extractDN(ldapEntry)
|
||||
|
||||
logger.debug(`👤 Extracted user DN: ${userDN} (type: ${typeof userDN})`)
|
||||
|
||||
// 验证用户DN
|
||||
if (!userDN) {
|
||||
logger.error(`❌ Invalid or missing DN for user: ${sanitizedUsername}`, {
|
||||
ldapEntryDn: ldapEntry.dn,
|
||||
ldapEntryObjectName: ldapEntry.objectName,
|
||||
ldapEntryType: typeof ldapEntry,
|
||||
extractedDN: userDN
|
||||
})
|
||||
return { success: false, message: 'Authentication service error' }
|
||||
}
|
||||
|
||||
// 4. 验证用户密码
|
||||
const isPasswordValid = await this.authenticateUser(userDN, password)
|
||||
if (!isPasswordValid) {
|
||||
logger.info(`🚫 Invalid password for user: ${sanitizedUsername}`)
|
||||
return { success: false, message: 'Invalid username or password' }
|
||||
}
|
||||
|
||||
// 5. 提取用户信息
|
||||
const userInfo = this.extractUserInfo(ldapEntry, sanitizedUsername)
|
||||
|
||||
// 6. 创建或更新本地用户
|
||||
const user = await userService.createOrUpdateUser(userInfo)
|
||||
|
||||
// 7. 检查用户是否被禁用
|
||||
if (!user.isActive) {
|
||||
logger.security(
|
||||
`🔒 Disabled user LDAP login attempt: ${sanitizedUsername} from LDAP authentication`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
message: 'Your account has been disabled. Please contact administrator.'
|
||||
}
|
||||
}
|
||||
|
||||
// 8. 记录登录
|
||||
await userService.recordUserLogin(user.id)
|
||||
|
||||
// 9. 创建用户会话
|
||||
const sessionToken = await userService.createUserSession(user.id)
|
||||
|
||||
logger.info(`✅ LDAP authentication successful for user: ${sanitizedUsername}`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
user,
|
||||
sessionToken,
|
||||
message: 'Authentication successful'
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ LDAP authentication error:', error)
|
||||
return {
|
||||
success: false,
|
||||
message: 'Authentication service unavailable'
|
||||
}
|
||||
} finally {
|
||||
// 确保客户端连接被关闭
|
||||
if (client) {
|
||||
client.unbind((err) => {
|
||||
if (err) {
|
||||
logger.debug('Error unbinding LDAP client:', err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 测试LDAP连接
|
||||
async testConnection() {
|
||||
if (!this.config.enabled) {
|
||||
return { success: false, message: 'LDAP is not enabled' }
|
||||
}
|
||||
|
||||
const client = this.createClient()
|
||||
|
||||
try {
|
||||
await this.bindClient(client)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
message: 'LDAP connection successful',
|
||||
server: this.config.server.url,
|
||||
searchBase: this.config.server.searchBase
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ LDAP connection test failed:', error)
|
||||
return {
|
||||
success: false,
|
||||
message: `LDAP connection failed: ${error.message}`,
|
||||
server: this.config.server.url
|
||||
}
|
||||
} finally {
|
||||
if (client) {
|
||||
client.unbind((err) => {
|
||||
if (err) {
|
||||
logger.debug('Error unbinding test LDAP client:', err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 获取LDAP配置信息(不包含敏感信息)
|
||||
getConfigInfo() {
|
||||
const configInfo = {
|
||||
enabled: this.config.enabled,
|
||||
server: {
|
||||
url: this.config.server.url,
|
||||
searchBase: this.config.server.searchBase,
|
||||
searchFilter: this.config.server.searchFilter,
|
||||
timeout: this.config.server.timeout,
|
||||
connectTimeout: this.config.server.connectTimeout
|
||||
},
|
||||
userMapping: this.config.userMapping
|
||||
}
|
||||
|
||||
// 添加 TLS 配置信息(不包含敏感数据)
|
||||
if (this.config.server.url.toLowerCase().startsWith('ldaps://') && this.config.server.tls) {
|
||||
configInfo.server.tls = {
|
||||
rejectUnauthorized: this.config.server.tls.rejectUnauthorized,
|
||||
hasCA: !!this.config.server.tls.ca,
|
||||
hasCert: !!this.config.server.tls.cert,
|
||||
hasKey: !!this.config.server.tls.key,
|
||||
servername: this.config.server.tls.servername
|
||||
}
|
||||
}
|
||||
|
||||
return configInfo
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new LdapService()
|
||||
514
src/services/userService.js
Normal file
514
src/services/userService.js
Normal file
@@ -0,0 +1,514 @@
|
||||
const redis = require('../models/redis')
|
||||
const crypto = require('crypto')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
|
||||
class UserService {
|
||||
constructor() {
|
||||
this.userPrefix = 'user:'
|
||||
this.usernamePrefix = 'username:'
|
||||
this.userSessionPrefix = 'user_session:'
|
||||
}
|
||||
|
||||
// 🔑 生成用户ID
|
||||
generateUserId() {
|
||||
return crypto.randomBytes(16).toString('hex')
|
||||
}
|
||||
|
||||
// 🔑 生成会话Token
|
||||
generateSessionToken() {
|
||||
return crypto.randomBytes(32).toString('hex')
|
||||
}
|
||||
|
||||
// 👤 创建或更新用户
|
||||
async createOrUpdateUser(userData) {
|
||||
try {
|
||||
const {
|
||||
username,
|
||||
email,
|
||||
displayName,
|
||||
firstName,
|
||||
lastName,
|
||||
role = config.userManagement.defaultUserRole,
|
||||
isActive = true
|
||||
} = userData
|
||||
|
||||
// 检查用户是否已存在
|
||||
let user = await this.getUserByUsername(username)
|
||||
const isNewUser = !user
|
||||
|
||||
if (isNewUser) {
|
||||
const userId = this.generateUserId()
|
||||
user = {
|
||||
id: userId,
|
||||
username,
|
||||
email,
|
||||
displayName,
|
||||
firstName,
|
||||
lastName,
|
||||
role,
|
||||
isActive,
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
lastLoginAt: null,
|
||||
apiKeyCount: 0,
|
||||
totalUsage: {
|
||||
requests: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalCost: 0
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 更新现有用户信息
|
||||
user = {
|
||||
...user,
|
||||
email,
|
||||
displayName,
|
||||
firstName,
|
||||
lastName,
|
||||
updatedAt: new Date().toISOString()
|
||||
}
|
||||
}
|
||||
|
||||
// 保存用户信息
|
||||
await redis.set(`${this.userPrefix}${user.id}`, JSON.stringify(user))
|
||||
await redis.set(`${this.usernamePrefix}${username}`, user.id)
|
||||
|
||||
logger.info(`📝 ${isNewUser ? 'Created' : 'Updated'} user: ${username} (${user.id})`)
|
||||
return user
|
||||
} catch (error) {
|
||||
logger.error('❌ Error creating/updating user:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 👤 通过用户名获取用户
|
||||
async getUserByUsername(username) {
|
||||
try {
|
||||
const userId = await redis.get(`${this.usernamePrefix}${username}`)
|
||||
if (!userId) {
|
||||
return null
|
||||
}
|
||||
|
||||
const userData = await redis.get(`${this.userPrefix}${userId}`)
|
||||
return userData ? JSON.parse(userData) : null
|
||||
} catch (error) {
|
||||
logger.error('❌ Error getting user by username:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 👤 通过ID获取用户
|
||||
async getUserById(userId, calculateUsage = true) {
|
||||
try {
|
||||
const userData = await redis.get(`${this.userPrefix}${userId}`)
|
||||
if (!userData) {
|
||||
return null
|
||||
}
|
||||
|
||||
const user = JSON.parse(userData)
|
||||
|
||||
// Calculate totalUsage by aggregating user's API keys usage (if requested)
|
||||
if (calculateUsage) {
|
||||
try {
|
||||
const usageStats = await this.calculateUserUsageStats(userId)
|
||||
user.totalUsage = usageStats.totalUsage
|
||||
user.apiKeyCount = usageStats.apiKeyCount
|
||||
} catch (error) {
|
||||
logger.error('❌ Error calculating user usage stats:', error)
|
||||
// Fallback to stored values if calculation fails
|
||||
user.totalUsage = user.totalUsage || {
|
||||
requests: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalCost: 0
|
||||
}
|
||||
user.apiKeyCount = user.apiKeyCount || 0
|
||||
}
|
||||
}
|
||||
|
||||
return user
|
||||
} catch (error) {
|
||||
logger.error('❌ Error getting user by ID:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 计算用户使用统计(通过聚合API Keys)
|
||||
async calculateUserUsageStats(userId) {
|
||||
try {
|
||||
// Use the existing apiKeyService method which already includes usage stats
|
||||
const apiKeyService = require('./apiKeyService')
|
||||
const userApiKeys = await apiKeyService.getUserApiKeys(userId, true) // Include deleted keys for stats
|
||||
|
||||
const totalUsage = {
|
||||
requests: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalCost: 0
|
||||
}
|
||||
|
||||
for (const apiKey of userApiKeys) {
|
||||
if (apiKey.usage && apiKey.usage.total) {
|
||||
totalUsage.requests += apiKey.usage.total.requests || 0
|
||||
totalUsage.inputTokens += apiKey.usage.total.inputTokens || 0
|
||||
totalUsage.outputTokens += apiKey.usage.total.outputTokens || 0
|
||||
totalUsage.totalCost += apiKey.totalCost || 0
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`📊 Calculated user ${userId} usage: ${totalUsage.requests} requests, ${totalUsage.inputTokens} input tokens, $${totalUsage.totalCost.toFixed(4)} total cost from ${userApiKeys.length} API keys`
|
||||
)
|
||||
|
||||
// Count only non-deleted API keys for the user's active count
|
||||
const activeApiKeyCount = userApiKeys.filter((key) => key.isDeleted !== 'true').length
|
||||
|
||||
return {
|
||||
totalUsage,
|
||||
apiKeyCount: activeApiKeyCount
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Error calculating user usage stats:', error)
|
||||
return {
|
||||
totalUsage: {
|
||||
requests: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalCost: 0
|
||||
},
|
||||
apiKeyCount: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 📋 获取所有用户列表(管理员功能)
|
||||
async getAllUsers(options = {}) {
|
||||
try {
|
||||
const client = redis.getClientSafe()
|
||||
const { page = 1, limit = 20, role, isActive } = options
|
||||
const pattern = `${this.userPrefix}*`
|
||||
const keys = await client.keys(pattern)
|
||||
|
||||
const users = []
|
||||
for (const key of keys) {
|
||||
const userData = await client.get(key)
|
||||
if (userData) {
|
||||
const user = JSON.parse(userData)
|
||||
|
||||
// 应用过滤条件
|
||||
if (role && user.role !== role) {
|
||||
continue
|
||||
}
|
||||
if (typeof isActive === 'boolean' && user.isActive !== isActive) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Calculate dynamic usage stats for each user
|
||||
try {
|
||||
const usageStats = await this.calculateUserUsageStats(user.id)
|
||||
user.totalUsage = usageStats.totalUsage
|
||||
user.apiKeyCount = usageStats.apiKeyCount
|
||||
} catch (error) {
|
||||
logger.error(`❌ Error calculating usage for user ${user.id}:`, error)
|
||||
// Fallback to stored values
|
||||
user.totalUsage = user.totalUsage || {
|
||||
requests: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalCost: 0
|
||||
}
|
||||
user.apiKeyCount = user.apiKeyCount || 0
|
||||
}
|
||||
|
||||
users.push(user)
|
||||
}
|
||||
}
|
||||
|
||||
// 排序和分页
|
||||
users.sort((a, b) => new Date(b.createdAt) - new Date(a.createdAt))
|
||||
const startIndex = (page - 1) * limit
|
||||
const endIndex = startIndex + limit
|
||||
const paginatedUsers = users.slice(startIndex, endIndex)
|
||||
|
||||
return {
|
||||
users: paginatedUsers,
|
||||
total: users.length,
|
||||
page,
|
||||
limit,
|
||||
totalPages: Math.ceil(users.length / limit)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Error getting all users:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔄 更新用户状态
|
||||
async updateUserStatus(userId, isActive) {
|
||||
try {
|
||||
const user = await this.getUserById(userId, false) // Skip usage calculation
|
||||
if (!user) {
|
||||
throw new Error('User not found')
|
||||
}
|
||||
|
||||
user.isActive = isActive
|
||||
user.updatedAt = new Date().toISOString()
|
||||
|
||||
await redis.set(`${this.userPrefix}${userId}`, JSON.stringify(user))
|
||||
logger.info(`🔄 Updated user status: ${user.username} -> ${isActive ? 'active' : 'disabled'}`)
|
||||
|
||||
// 如果禁用用户,删除所有会话并禁用其所有API Keys
|
||||
if (!isActive) {
|
||||
await this.invalidateUserSessions(userId)
|
||||
|
||||
// Disable all user's API keys when user is disabled
|
||||
try {
|
||||
const apiKeyService = require('./apiKeyService')
|
||||
const result = await apiKeyService.disableUserApiKeys(userId)
|
||||
logger.info(`🔑 Disabled ${result.count} API keys for disabled user: ${user.username}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ Error disabling user API keys during user disable:', error)
|
||||
}
|
||||
}
|
||||
|
||||
return user
|
||||
} catch (error) {
|
||||
logger.error('❌ Error updating user status:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔄 更新用户角色
|
||||
async updateUserRole(userId, role) {
|
||||
try {
|
||||
const user = await this.getUserById(userId, false) // Skip usage calculation
|
||||
if (!user) {
|
||||
throw new Error('User not found')
|
||||
}
|
||||
|
||||
user.role = role
|
||||
user.updatedAt = new Date().toISOString()
|
||||
|
||||
await redis.set(`${this.userPrefix}${userId}`, JSON.stringify(user))
|
||||
logger.info(`🔄 Updated user role: ${user.username} -> ${role}`)
|
||||
|
||||
return user
|
||||
} catch (error) {
|
||||
logger.error('❌ Error updating user role:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 更新用户API Key数量 (已废弃,现在通过聚合计算)
|
||||
async updateUserApiKeyCount(userId, _count) {
|
||||
// This method is deprecated since apiKeyCount is now calculated dynamically
|
||||
// in getUserById by aggregating the user's API keys
|
||||
logger.debug(
|
||||
`📊 updateUserApiKeyCount called for ${userId} but is now deprecated (count auto-calculated)`
|
||||
)
|
||||
}
|
||||
|
||||
// 📝 记录用户登录
|
||||
async recordUserLogin(userId) {
|
||||
try {
|
||||
const user = await this.getUserById(userId, false) // Skip usage calculation
|
||||
if (!user) {
|
||||
return
|
||||
}
|
||||
|
||||
user.lastLoginAt = new Date().toISOString()
|
||||
await redis.set(`${this.userPrefix}${userId}`, JSON.stringify(user))
|
||||
} catch (error) {
|
||||
logger.error('❌ Error recording user login:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 🎫 创建用户会话
|
||||
async createUserSession(userId, sessionData = {}) {
|
||||
try {
|
||||
const sessionToken = this.generateSessionToken()
|
||||
const session = {
|
||||
token: sessionToken,
|
||||
userId,
|
||||
createdAt: new Date().toISOString(),
|
||||
expiresAt: new Date(Date.now() + config.userManagement.userSessionTimeout).toISOString(),
|
||||
...sessionData
|
||||
}
|
||||
|
||||
const ttl = Math.floor(config.userManagement.userSessionTimeout / 1000)
|
||||
await redis.setex(`${this.userSessionPrefix}${sessionToken}`, ttl, JSON.stringify(session))
|
||||
|
||||
logger.info(`🎫 Created session for user: ${userId}`)
|
||||
return sessionToken
|
||||
} catch (error) {
|
||||
logger.error('❌ Error creating user session:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🎫 验证用户会话
|
||||
async validateUserSession(sessionToken) {
|
||||
try {
|
||||
const sessionData = await redis.get(`${this.userSessionPrefix}${sessionToken}`)
|
||||
if (!sessionData) {
|
||||
return null
|
||||
}
|
||||
|
||||
const session = JSON.parse(sessionData)
|
||||
|
||||
// 检查会话是否过期
|
||||
if (new Date() > new Date(session.expiresAt)) {
|
||||
await this.invalidateUserSession(sessionToken)
|
||||
return null
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
const user = await this.getUserById(session.userId, false) // Skip usage calculation for validation
|
||||
if (!user || !user.isActive) {
|
||||
await this.invalidateUserSession(sessionToken)
|
||||
return null
|
||||
}
|
||||
|
||||
return { session, user }
|
||||
} catch (error) {
|
||||
logger.error('❌ Error validating user session:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
// 🚫 使用户会话失效
|
||||
async invalidateUserSession(sessionToken) {
|
||||
try {
|
||||
await redis.del(`${this.userSessionPrefix}${sessionToken}`)
|
||||
logger.info(`🚫 Invalidated session: ${sessionToken}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ Error invalidating user session:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 🚫 使用户所有会话失效
|
||||
async invalidateUserSessions(userId) {
|
||||
try {
|
||||
const client = redis.getClientSafe()
|
||||
const pattern = `${this.userSessionPrefix}*`
|
||||
const keys = await client.keys(pattern)
|
||||
|
||||
for (const key of keys) {
|
||||
const sessionData = await client.get(key)
|
||||
if (sessionData) {
|
||||
const session = JSON.parse(sessionData)
|
||||
if (session.userId === userId) {
|
||||
await client.del(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`🚫 Invalidated all sessions for user: ${userId}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ Error invalidating user sessions:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 🗑️ 删除用户(软删除,标记为不活跃)
|
||||
async deleteUser(userId) {
|
||||
try {
|
||||
const user = await this.getUserById(userId, false) // Skip usage calculation
|
||||
if (!user) {
|
||||
throw new Error('User not found')
|
||||
}
|
||||
|
||||
// 软删除:标记为不活跃并添加删除时间戳
|
||||
user.isActive = false
|
||||
user.deletedAt = new Date().toISOString()
|
||||
user.updatedAt = new Date().toISOString()
|
||||
|
||||
await redis.set(`${this.userPrefix}${userId}`, JSON.stringify(user))
|
||||
|
||||
// 删除所有会话
|
||||
await this.invalidateUserSessions(userId)
|
||||
|
||||
// Disable all user's API keys when user is deleted
|
||||
try {
|
||||
const apiKeyService = require('./apiKeyService')
|
||||
const result = await apiKeyService.disableUserApiKeys(userId)
|
||||
logger.info(`🔑 Disabled ${result.count} API keys for deleted user: ${user.username}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ Error disabling user API keys during user deletion:', error)
|
||||
}
|
||||
|
||||
logger.info(`🗑️ Soft deleted user: ${user.username} (${userId})`)
|
||||
return user
|
||||
} catch (error) {
|
||||
logger.error('❌ Error deleting user:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 获取用户统计信息
|
||||
async getUserStats() {
|
||||
try {
|
||||
const client = redis.getClientSafe()
|
||||
const pattern = `${this.userPrefix}*`
|
||||
const keys = await client.keys(pattern)
|
||||
|
||||
const stats = {
|
||||
totalUsers: 0,
|
||||
activeUsers: 0,
|
||||
adminUsers: 0,
|
||||
regularUsers: 0,
|
||||
totalApiKeys: 0,
|
||||
totalUsage: {
|
||||
requests: 0,
|
||||
inputTokens: 0,
|
||||
outputTokens: 0,
|
||||
totalCost: 0
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of keys) {
|
||||
const userData = await client.get(key)
|
||||
if (userData) {
|
||||
const user = JSON.parse(userData)
|
||||
stats.totalUsers++
|
||||
|
||||
if (user.isActive) {
|
||||
stats.activeUsers++
|
||||
}
|
||||
|
||||
if (user.role === 'admin') {
|
||||
stats.adminUsers++
|
||||
} else {
|
||||
stats.regularUsers++
|
||||
}
|
||||
|
||||
// Calculate dynamic usage stats for each user
|
||||
try {
|
||||
const usageStats = await this.calculateUserUsageStats(user.id)
|
||||
stats.totalApiKeys += usageStats.apiKeyCount
|
||||
stats.totalUsage.requests += usageStats.totalUsage.requests
|
||||
stats.totalUsage.inputTokens += usageStats.totalUsage.inputTokens
|
||||
stats.totalUsage.outputTokens += usageStats.totalUsage.outputTokens
|
||||
stats.totalUsage.totalCost += usageStats.totalUsage.totalCost
|
||||
} catch (error) {
|
||||
logger.error(`❌ Error calculating usage for user ${user.id} in stats:`, error)
|
||||
// Fallback to stored values if calculation fails
|
||||
stats.totalApiKeys += user.apiKeyCount || 0
|
||||
stats.totalUsage.requests += user.totalUsage?.requests || 0
|
||||
stats.totalUsage.inputTokens += user.totalUsage?.inputTokens || 0
|
||||
stats.totalUsage.outputTokens += user.totalUsage?.outputTokens || 0
|
||||
stats.totalUsage.totalCost += user.totalUsage?.totalCost || 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
} catch (error) {
|
||||
logger.error('❌ Error getting user stats:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new UserService()
|
||||
Reference in New Issue
Block a user