From 33837c23aa5aa8d41a1027095b31ea93dbef6265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=83=E7=BE=BD?= Date: Mon, 4 Aug 2025 14:47:03 +0900 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BA=20Gemini=20=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E6=94=AF=E6=8C=81=E5=B9=B6=E6=B7=BB=E5=8A=A0=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E8=B0=83=E5=BA=A6=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 unifiedGeminiScheduler.js 统一账户调度服务 - 增强 geminiRoutes.js 支持更多 Gemini API 端点 - 优化 geminiAccountService.js 账户管理和 token 刷新机制 - 添加对 v1internal 端点的完整支持(loadCodeAssist、onboardUser、countTokens、generateContent、streamGenerateContent) - 改进错误处理和流式响应管理 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/routes/geminiRoutes.js | 318 +++++++++++++-- src/services/geminiAccountService.js | 534 ++++++++++++++++++++----- src/services/unifiedGeminiScheduler.js | 376 +++++++++++++++++ 3 files changed, 1100 insertions(+), 128 deletions(-) create mode 100644 src/services/unifiedGeminiScheduler.js diff --git a/src/routes/geminiRoutes.js b/src/routes/geminiRoutes.js index 45e707e1..c48a85fc 100644 --- a/src/routes/geminiRoutes.js +++ b/src/routes/geminiRoutes.js @@ -5,6 +5,9 @@ const { authenticateApiKey } = require('../middleware/auth'); const geminiAccountService = require('../services/geminiAccountService'); const { sendGeminiRequest, getAvailableModels } = require('../services/geminiRelayService'); const crypto = require('crypto'); +const sessionHelper = require('../utils/sessionHelper'); +const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler'); +const { OAuth2Client } = require('google-auth-library'); // 生成会话哈希 function generateSessionHash(req) { @@ -13,7 +16,7 @@ function generateSessionHash(req) { req.ip, req.headers['x-api-key']?.substring(0, 10) ].filter(Boolean).join(':'); - + return crypto.createHash('sha256').update(sessionData).digest('hex'); } @@ -27,10 +30,10 @@ function checkPermissions(apiKeyData, requiredPermission = 'gemini') { router.post('/messages', authenticateApiKey, async (req, res) => { const startTime = Date.now(); let abortController = null; - + try { const apiKeyData = req.apiKey; - + // 检查权限 if (!checkPermissions(apiKeyData, 'gemini')) { return res.status(403).json({ @@ -40,7 +43,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => { } }); } - + // 提取请求参数 const { messages, @@ -49,7 +52,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => { max_tokens = 4096, stream = false } = req.body; - + // 验证必需参数 if (!messages || !Array.isArray(messages) || messages.length === 0) { return res.status(400).json({ @@ -59,16 +62,16 @@ router.post('/messages', authenticateApiKey, async (req, res) => { } }); } - + // 生成会话哈希用于粘性会话 const sessionHash = generateSessionHash(req); - + // 选择可用的 Gemini 账户 const account = await geminiAccountService.selectAvailableAccount( apiKeyData.id, sessionHash ); - + if (!account) { return res.status(503).json({ error: { @@ -77,15 +80,15 @@ router.post('/messages', authenticateApiKey, async (req, res) => { } }); } - + logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`); - + // 标记账户被使用 await geminiAccountService.markAccountUsed(account.id); - + // 创建中止控制器 abortController = new AbortController(); - + // 处理客户端断开连接 req.on('close', () => { if (abortController && !abortController.signal.aborted) { @@ -93,7 +96,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => { abortController.abort(); } }); - + // 发送请求到 Gemini const geminiResponse = await sendGeminiRequest({ messages, @@ -107,14 +110,14 @@ router.post('/messages', authenticateApiKey, async (req, res) => { signal: abortController.signal, projectId: account.projectId }); - + if (stream) { // 设置流式响应头 res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache'); res.setHeader('Connection', 'keep-alive'); res.setHeader('X-Accel-Buffering', 'no'); - + // 流式传输响应 for await (const chunk of geminiResponse) { if (abortController.signal.aborted) { @@ -122,26 +125,26 @@ router.post('/messages', authenticateApiKey, async (req, res) => { } res.write(chunk); } - + res.end(); } else { // 非流式响应 res.json(geminiResponse); } - + const duration = Date.now() - startTime; logger.info(`Gemini request completed in ${duration}ms`); - + } catch (error) { logger.error('Gemini request error:', error); - + // 处理速率限制 if (error.status === 429) { if (req.apiKey && req.account) { await geminiAccountService.setAccountRateLimited(req.account.id, true); } } - + // 返回错误响应 const status = error.status || 500; const errorResponse = { @@ -150,7 +153,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => { type: 'api_error' } }; - + res.status(status).json(errorResponse); } finally { // 清理资源 @@ -164,7 +167,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => { router.get('/models', authenticateApiKey, async (req, res) => { try { const apiKeyData = req.apiKey; - + // 检查权限 if (!checkPermissions(apiKeyData, 'gemini')) { return res.status(403).json({ @@ -174,10 +177,10 @@ router.get('/models', authenticateApiKey, async (req, res) => { } }); } - + // 选择账户获取模型列表 const account = await geminiAccountService.selectAvailableAccount(apiKeyData.id); - + if (!account) { // 返回默认模型列表 return res.json({ @@ -192,15 +195,15 @@ router.get('/models', authenticateApiKey, async (req, res) => { ] }); } - + // 获取模型列表 const models = await getAvailableModels(account.accessToken, account.proxy); - + res.json({ object: 'list', data: models }); - + } catch (error) { logger.error('Failed to get Gemini models:', error); res.status(500).json({ @@ -216,7 +219,7 @@ router.get('/models', authenticateApiKey, async (req, res) => { router.get('/usage', authenticateApiKey, async (req, res) => { try { const usage = req.apiKey.usage; - + res.json({ object: 'usage', total_tokens: usage.total.tokens, @@ -241,14 +244,14 @@ router.get('/usage', authenticateApiKey, async (req, res) => { router.get('/key-info', authenticateApiKey, async (req, res) => { try { const keyData = req.apiKey; - + res.json({ id: keyData.id, name: keyData.name, permissions: keyData.permissions || 'all', token_limit: keyData.tokenLimit, tokens_used: keyData.usage.total.tokens, - tokens_remaining: keyData.tokenLimit > 0 + tokens_remaining: keyData.tokenLimit > 0 ? Math.max(0, keyData.tokenLimit - keyData.usage.total.tokens) : null, rate_limit: { @@ -272,4 +275,259 @@ router.get('/key-info', authenticateApiKey, async (req, res) => { } }); +router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, async (req, res) => { + try { + + const sessionHash = sessionHelper.generateSessionHash(req.body); + + // 使用统一调度选择账号(传递请求的模型) + const requestedModel = req.body.model; + const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel); + const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId); + logger.info(`accessToken: ${accessToken}`); + + const { metadata, cloudaicompanionProject } = req.body; + + logger.info('LoadCodeAssist request', { + metadata: metadata || {}, + cloudaicompanionProject: cloudaicompanionProject || null, + apiKeyId: req.apiKey?.id || 'unknown' + }); + + const client = await geminiAccountService.getOauthClient(accessToken, refreshToken); + const response = await geminiAccountService.loadCodeAssist(client, cloudaicompanionProject); + res.json(response); + } catch (error) { + logger.error('Error in loadCodeAssist endpoint', { error: error.message }); + res.status(500).json({ + error: 'Internal server error', + message: error.message + }); + } +}); + +router.post('/v1internal\\:onboardUser', authenticateApiKey, async (req, res) => { + try { + const { tierId, cloudaicompanionProject, metadata } = req.body; + const sessionHash = sessionHelper.generateSessionHash(req.body); + + // 使用统一调度选择账号(传递请求的模型) + const requestedModel = req.body.model; + const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel); + const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId); + + logger.info('OnboardUser request', { + tierId: tierId || 'not provided', + cloudaicompanionProject: cloudaicompanionProject || null, + metadata: metadata || {}, + apiKeyId: req.apiKey?.id || 'unknown' + }); + + const client = await geminiAccountService.getOauthClient(accessToken, refreshToken); + + // 如果提供了完整参数,直接调用onboardUser + if (tierId && metadata) { + const response = await geminiAccountService.onboardUser(client, tierId, cloudaicompanionProject, metadata); + res.json(response); + } else { + // 否则执行完整的setupUser流程 + const response = await geminiAccountService.setupUser(client, cloudaicompanionProject, metadata); + res.json(response); + } + } catch (error) { + logger.error('Error in onboardUser endpoint', { error: error.message }); + res.status(500).json({ + error: 'Internal server error', + message: error.message + }); + } +}); + +router.post('/v1internal\\:countTokens', authenticateApiKey, async (req, res) => { + try { + // 处理请求体结构,支持直接 contents 或 request.contents + const requestData = req.body.request || req.body; + const { contents, model = 'gemini-2.0-flash-exp' } = requestData; + const sessionHash = sessionHelper.generateSessionHash(req.body); + + // 验证必需参数 + if (!contents || !Array.isArray(contents)) { + return res.status(400).json({ + error: { + message: 'Contents array is required', + type: 'invalid_request_error' + } + }); + } + + // 使用统一调度选择账号 + const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model); + const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId); + + logger.info('CountTokens request', { + model: model, + contentsLength: contents.length, + apiKeyId: req.apiKey?.id || 'unknown' + }); + + const client = await geminiAccountService.getOauthClient(accessToken, refreshToken); + const response = await geminiAccountService.countTokens(client, contents, model); + + res.json(response); + } catch (error) { + logger.error('Error in countTokens endpoint', { error: error.message }); + res.status(500).json({ + error: { + message: error.message || 'Internal server error', + type: 'api_error' + } + }); + } +}); + +router.post('/v1internal\\:generateContent', authenticateApiKey, async (req, res) => { + try { + const { model, project, user_prompt_id, request: requestData } = req.body; + const sessionHash = sessionHelper.generateSessionHash(req.body); + + // 验证必需参数 + if (!requestData || !requestData.contents) { + return res.status(400).json({ + error: { + message: 'Request contents are required', + type: 'invalid_request_error' + } + }); + } + + // 使用统一调度选择账号 + const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model); + const account = await geminiAccountService.getAccount(accountId); + const { accessToken, refreshToken } = account; + + logger.info('GenerateContent request', { + model: model, + userPromptId: user_prompt_id, + projectId: project || account.projectId, + apiKeyId: req.apiKey?.id || 'unknown' + }); + + const client = await geminiAccountService.getOauthClient(accessToken, refreshToken); + const response = await geminiAccountService.generateContent( + client, + { model, request: requestData }, + user_prompt_id, + project || account.projectId, + req.apiKey?.id // 使用 API Key ID 作为 session ID + ); + + res.json(response); + } catch (error) { + logger.error('Error in generateContent endpoint', { error: error.message }); + res.status(500).json({ + error: { + message: error.message || 'Internal server error', + type: 'api_error' + } + }); + } +}); + +router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, async (req, res) => { + let abortController = null; + + try { + const { model, project, user_prompt_id, request: requestData } = req.body; + const sessionHash = sessionHelper.generateSessionHash(req.body); + + // 验证必需参数 + if (!requestData || !requestData.contents) { + return res.status(400).json({ + error: { + message: 'Request contents are required', + type: 'invalid_request_error' + } + }); + } + + // 使用统一调度选择账号 + const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model); + const account = await geminiAccountService.getAccount(accountId); + const { accessToken, refreshToken } = account; + + logger.info('StreamGenerateContent request', { + model: model, + userPromptId: user_prompt_id, + projectId: project || account.projectId, + apiKeyId: req.apiKey?.id || 'unknown' + }); + + // 创建中止控制器 + abortController = new AbortController(); + + // 处理客户端断开连接 + req.on('close', () => { + if (abortController && !abortController.signal.aborted) { + logger.info('Client disconnected, aborting stream request'); + abortController.abort(); + } + }); + + const client = await geminiAccountService.getOauthClient(accessToken, refreshToken); + const streamResponse = await geminiAccountService.generateContentStream( + client, + { model, request: requestData }, + user_prompt_id, + project || account.projectId, + req.apiKey?.id, // 使用 API Key ID 作为 session ID + abortController.signal // 传递中止信号 + ); + + // 设置 SSE 响应头 + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + res.setHeader('X-Accel-Buffering', 'no'); + + // 直接管道转发流式响应,不进行额外处理 + streamResponse.pipe(res, { end: false }); + + streamResponse.on('end', () => { + logger.info('Stream completed successfully'); + res.end(); + }); + + streamResponse.on('error', (error) => { + logger.error('Stream error:', error); + if (!res.headersSent) { + res.status(500).json({ + error: { + message: error.message || 'Stream error', + type: 'api_error' + } + }); + } else { + res.end(); + } + }); + + } catch (error) { + logger.error('Error in streamGenerateContent endpoint', { error: error.message }); + + if (!res.headersSent) { + res.status(500).json({ + error: { + message: error.message || 'Internal server error', + type: 'api_error' + } + }); + } + } finally { + // 清理资源 + if (abortController) { + abortController = null; + } + } +}); + module.exports = router; \ No newline at end of file diff --git a/src/services/geminiAccountService.js b/src/services/geminiAccountService.js index 07b25f48..dfb2bb32 100644 --- a/src/services/geminiAccountService.js +++ b/src/services/geminiAccountService.js @@ -53,7 +53,7 @@ function decrypt(text) { // IV 是固定长度的 32 个十六进制字符(16 字节) const ivHex = text.substring(0, 32); const encryptedHex = text.substring(33); // 跳过冒号 - + const iv = Buffer.from(ivHex, 'hex'); const encryptedText = Buffer.from(encryptedHex, 'hex'); const decipher = crypto.createDecipheriv(ALGORITHM, key, iv); @@ -82,11 +82,11 @@ async function generateAuthUrl(state = null, redirectUri = null) { // 使用新的 redirect URI const finalRedirectUri = redirectUri || 'https://codeassist.google.com/authcode'; const oAuth2Client = createOAuth2Client(finalRedirectUri); - + // 生成 PKCE code verifier const codeVerifier = await oAuth2Client.generateCodeVerifierAsync(); const stateValue = state || crypto.randomBytes(32).toString('hex'); - + const authUrl = oAuth2Client.generateAuthUrl({ redirect_uri: finalRedirectUri, access_type: 'offline', @@ -96,7 +96,7 @@ async function generateAuthUrl(state = null, redirectUri = null) { state: stateValue, prompt: 'select_account' }); - + return { authUrl, state: stateValue, @@ -109,28 +109,28 @@ async function generateAuthUrl(state = null, redirectUri = null) { async function pollAuthorizationStatus(sessionId, maxAttempts = 60, interval = 2000) { let attempts = 0; const client = redisClient.getClientSafe(); - + while (attempts < maxAttempts) { try { const sessionData = await client.get(`oauth_session:${sessionId}`); if (!sessionData) { throw new Error('OAuth session not found'); } - + const session = JSON.parse(sessionData); if (session.code) { // 授权码已获取,交换 tokens const tokens = await exchangeCodeForTokens(session.code); - + // 清理 session await client.del(`oauth_session:${sessionId}`); - + return { success: true, tokens }; } - + if (session.error) { // 授权失败 await client.del(`oauth_session:${sessionId}`); @@ -139,7 +139,7 @@ async function pollAuthorizationStatus(sessionId, maxAttempts = 60, interval = 2 error: session.error }; } - + // 等待下一次轮询 await new Promise(resolve => setTimeout(resolve, interval)); attempts++; @@ -148,7 +148,7 @@ async function pollAuthorizationStatus(sessionId, maxAttempts = 60, interval = 2 throw error; } } - + // 超时 await client.del(`oauth_session:${sessionId}`); return { @@ -160,20 +160,20 @@ async function pollAuthorizationStatus(sessionId, maxAttempts = 60, interval = 2 // 交换授权码获取 tokens (支持 PKCE) async function exchangeCodeForTokens(code, redirectUri = null, codeVerifier = null) { const oAuth2Client = createOAuth2Client(redirectUri); - + try { const tokenParams = { code: code, redirect_uri: redirectUri }; - + // 如果提供了 codeVerifier,添加到参数中 if (codeVerifier) { tokenParams.codeVerifier = codeVerifier; } - + const { tokens } = await oAuth2Client.getToken(tokenParams); - + // 转换为兼容格式 return { access_token: tokens.access_token, @@ -191,24 +191,24 @@ async function exchangeCodeForTokens(code, redirectUri = null, codeVerifier = nu // 刷新访问令牌 async function refreshAccessToken(refreshToken) { const oAuth2Client = createOAuth2Client(); - + try { // 设置 refresh_token oAuth2Client.setCredentials({ refresh_token: refreshToken }); - + // 调用 refreshAccessToken 获取新的 tokens const response = await oAuth2Client.refreshAccessToken(); const credentials = response.credentials; - + // 检查是否成功获取了新的 access_token if (!credentials || !credentials.access_token) { throw new Error('No access token returned from refresh'); } - + logger.info(`🔄 Successfully refreshed Gemini token. New expiry: ${new Date(credentials.expiry_date).toISOString()}`); - + return { access_token: credentials.access_token, refresh_token: credentials.refresh_token || refreshToken, // 保留原 refresh_token 如果没有返回新的 @@ -230,34 +230,34 @@ async function refreshAccessToken(refreshToken) { async function createAccount(accountData) { const id = uuidv4(); const now = new Date().toISOString(); - + // 处理凭证数据 let geminiOauth = null; let accessToken = ''; let refreshToken = ''; let expiresAt = ''; - + if (accountData.geminiOauth || accountData.accessToken) { // 如果提供了完整的 OAuth 数据 if (accountData.geminiOauth) { - geminiOauth = typeof accountData.geminiOauth === 'string' - ? accountData.geminiOauth + geminiOauth = typeof accountData.geminiOauth === 'string' + ? accountData.geminiOauth : JSON.stringify(accountData.geminiOauth); - - const oauthData = typeof accountData.geminiOauth === 'string' + + const oauthData = typeof accountData.geminiOauth === 'string' ? JSON.parse(accountData.geminiOauth) : accountData.geminiOauth; - + accessToken = oauthData.access_token || ''; refreshToken = oauthData.refresh_token || ''; - expiresAt = oauthData.expiry_date + expiresAt = oauthData.expiry_date ? new Date(oauthData.expiry_date).toISOString() : ''; } else { // 如果只提供了 access token accessToken = accountData.accessToken; refreshToken = accountData.refreshToken || ''; - + // 构造完整的 OAuth 数据 geminiOauth = JSON.stringify({ access_token: accessToken, @@ -266,11 +266,11 @@ async function createAccount(accountData) { token_type: accountData.tokenType || 'Bearer', expiry_date: accountData.expiryDate || Date.now() + 3600000 // 默认1小时 }); - + expiresAt = new Date(accountData.expiryDate || Date.now() + 3600000).toISOString(); } } - + const account = { id, platform: 'gemini', // 标识为 Gemini 账户 @@ -279,39 +279,39 @@ async function createAccount(accountData) { accountType: accountData.accountType || 'shared', isActive: 'true', status: 'active', - + // OAuth 相关字段(加密存储) geminiOauth: geminiOauth ? encrypt(geminiOauth) : '', accessToken: accessToken ? encrypt(accessToken) : '', refreshToken: refreshToken ? encrypt(refreshToken) : '', expiresAt, scopes: accountData.scopes || OAUTH_SCOPES.join(' '), - + // 代理设置 proxy: accountData.proxy ? JSON.stringify(accountData.proxy) : '', - + // 项目编号(Google Cloud/Workspace 账号需要) projectId: accountData.projectId || '', - + // 时间戳 createdAt: now, updatedAt: now, lastUsedAt: '', lastRefreshAt: '' }; - + // 保存到 Redis const client = redisClient.getClientSafe(); await client.hset( `${GEMINI_ACCOUNT_KEY_PREFIX}${id}`, account ); - + // 如果是共享账户,添加到共享账户集合 if (account.accountType === 'shared') { await client.sadd(SHARED_GEMINI_ACCOUNTS_KEY, id); } - + logger.info(`Created Gemini account: ${id}`); return account; } @@ -320,11 +320,11 @@ async function createAccount(accountData) { async function getAccount(accountId) { const client = redisClient.getClientSafe(); const accountData = await client.hgetall(`${GEMINI_ACCOUNT_KEY_PREFIX}${accountId}`); - + if (!accountData || Object.keys(accountData).length === 0) { return null; } - + // 解密敏感字段 if (accountData.geminiOauth) { accountData.geminiOauth = decrypt(accountData.geminiOauth); @@ -335,7 +335,7 @@ async function getAccount(accountId) { if (accountData.refreshToken) { accountData.refreshToken = decrypt(accountData.refreshToken); } - + return accountData; } @@ -345,20 +345,20 @@ async function updateAccount(accountId, updates) { if (!existingAccount) { throw new Error('Account not found'); } - + const now = new Date().toISOString(); updates.updatedAt = now; - + // 检查是否新增了 refresh token // existingAccount.refreshToken 已经是解密后的值了(从 getAccount 返回) const oldRefreshToken = existingAccount.refreshToken || ''; let needUpdateExpiry = false; - + // 加密敏感字段 if (updates.geminiOauth) { updates.geminiOauth = encrypt( - typeof updates.geminiOauth === 'string' - ? updates.geminiOauth + typeof updates.geminiOauth === 'string' + ? updates.geminiOauth : JSON.stringify(updates.geminiOauth) ); } @@ -372,7 +372,7 @@ async function updateAccount(accountId, updates) { needUpdateExpiry = true; } } - + // 更新账户类型时处理共享账户集合 const client = redisClient.getClientSafe(); if (updates.accountType && updates.accountType !== existingAccount.accountType) { @@ -382,26 +382,26 @@ async function updateAccount(accountId, updates) { await client.srem(SHARED_GEMINI_ACCOUNTS_KEY, accountId); } } - + // 如果新增了 refresh token,更新过期时间为10分钟 if (needUpdateExpiry) { const newExpiry = new Date(Date.now() + (10 * 60 * 1000)).toISOString(); updates.expiresAt = newExpiry; logger.info(`🔄 New refresh token added for Gemini account ${accountId}, setting expiry to 10 minutes`); } - + // 如果通过 geminiOauth 更新,也要检查是否新增了 refresh token if (updates.geminiOauth && !oldRefreshToken) { - const oauthData = typeof updates.geminiOauth === 'string' + const oauthData = typeof updates.geminiOauth === 'string' ? JSON.parse(decrypt(updates.geminiOauth)) : updates.geminiOauth; - + if (oauthData.refresh_token) { // 如果 expiry_date 设置的时间过长(超过1小时),调整为10分钟 const providedExpiry = oauthData.expiry_date || 0; const now = Date.now(); const oneHour = 60 * 60 * 1000; - + if (providedExpiry - now > oneHour) { const newExpiry = new Date(now + (10 * 60 * 1000)).toISOString(); updates.expiresAt = newExpiry; @@ -409,12 +409,12 @@ async function updateAccount(accountId, updates) { } } } - + await client.hset( `${GEMINI_ACCOUNT_KEY_PREFIX}${accountId}`, updates ); - + logger.info(`Updated Gemini account: ${accountId}`); return { ...existingAccount, ...updates }; } @@ -425,16 +425,16 @@ async function deleteAccount(accountId) { if (!account) { throw new Error('Account not found'); } - + // 从 Redis 删除 const client = redisClient.getClientSafe(); await client.del(`${GEMINI_ACCOUNT_KEY_PREFIX}${accountId}`); - + // 从共享账户集合中移除 if (account.accountType === 'shared') { await client.srem(SHARED_GEMINI_ACCOUNTS_KEY, accountId); } - + // 清理会话映射 const sessionMappings = await client.keys(`${ACCOUNT_SESSION_MAPPING_PREFIX}*`); for (const key of sessionMappings) { @@ -443,7 +443,7 @@ async function deleteAccount(accountId) { await client.del(key); } } - + logger.info(`Deleted Gemini account: ${accountId}`); return true; } @@ -453,7 +453,7 @@ async function getAllAccounts() { const client = redisClient.getClientSafe(); const keys = await client.keys(`${GEMINI_ACCOUNT_KEY_PREFIX}*`); const accounts = []; - + for (const key of keys) { const accountData = await client.hgetall(key); if (accountData && Object.keys(accountData).length > 0) { @@ -466,7 +466,7 @@ async function getAllAccounts() { }); } } - + return accounts; } @@ -478,7 +478,7 @@ async function selectAvailableAccount(apiKeyId, sessionHash = null) { const mappedAccountId = await client.get( `${ACCOUNT_SESSION_MAPPING_PREFIX}${sessionHash}` ); - + if (mappedAccountId) { const account = await getAccount(mappedAccountId); if (account && account.isActive === 'true' && !isTokenExpired(account)) { @@ -487,25 +487,25 @@ async function selectAvailableAccount(apiKeyId, sessionHash = null) { } } } - + // 获取 API Key 信息 const apiKeyData = await client.hgetall(`api_key:${apiKeyId}`); - + // 检查是否绑定了 Gemini 账户 if (apiKeyData.geminiAccountId) { const account = await getAccount(apiKeyData.geminiAccountId); if (account && account.isActive === 'true') { // 检查 token 是否过期 const isExpired = isTokenExpired(account); - + // 记录token使用情况 logTokenUsage(account.id, account.name, 'gemini', account.expiresAt, isExpired); - + if (isExpired) { await refreshAccountToken(account.id); return await getAccount(account.id); } - + // 创建粘性会话映射 if (sessionHash) { await client.setex( @@ -514,46 +514,46 @@ async function selectAvailableAccount(apiKeyId, sessionHash = null) { account.id ); } - + return account; } } - + // 从共享账户池选择 const sharedAccountIds = await client.smembers(SHARED_GEMINI_ACCOUNTS_KEY); const availableAccounts = []; - + for (const accountId of sharedAccountIds) { const account = await getAccount(accountId); if (account && account.isActive === 'true' && !isRateLimited(account)) { availableAccounts.push(account); } } - + if (availableAccounts.length === 0) { throw new Error('No available Gemini accounts'); } - + // 选择最少使用的账户 availableAccounts.sort((a, b) => { const aLastUsed = a.lastUsedAt ? new Date(a.lastUsedAt).getTime() : 0; const bLastUsed = b.lastUsedAt ? new Date(b.lastUsedAt).getTime() : 0; return aLastUsed - bLastUsed; }); - + const selectedAccount = availableAccounts[0]; - + // 检查并刷新 token const isExpired = isTokenExpired(selectedAccount); - + // 记录token使用情况 logTokenUsage(selectedAccount.id, selectedAccount.name, 'gemini', selectedAccount.expiresAt, isExpired); - + if (isExpired) { await refreshAccountToken(selectedAccount.id); return await getAccount(selectedAccount.id); } - + // 创建粘性会话映射 if (sessionHash) { await client.setex( @@ -562,18 +562,18 @@ async function selectAvailableAccount(apiKeyId, sessionHash = null) { selectedAccount.id ); } - + return selectedAccount; } // 检查 token 是否过期 function isTokenExpired(account) { if (!account.expiresAt) return true; - + const expiryTime = new Date(account.expiresAt).getTime(); const now = Date.now(); const buffer = 10 * 1000; // 10秒缓冲 - + return now >= (expiryTime - buffer); } @@ -583,7 +583,7 @@ function isRateLimited(account) { const limitedAt = new Date(account.rateLimitedAt).getTime(); const now = Date.now(); const limitDuration = 60 * 60 * 1000; // 1小时 - + return now < (limitedAt + limitDuration); } return false; @@ -593,28 +593,28 @@ function isRateLimited(account) { async function refreshAccountToken(accountId) { let lockAcquired = false; let account = null; - + try { account = await getAccount(accountId); if (!account) { throw new Error('Account not found'); } - + if (!account.refreshToken) { throw new Error('No refresh token available'); } - + // 尝试获取分布式锁 lockAcquired = await tokenRefreshService.acquireRefreshLock(accountId, 'gemini'); - + if (!lockAcquired) { // 如果无法获取锁,说明另一个进程正在刷新 logger.info(`🔒 Token refresh already in progress for Gemini account: ${account.name} (${accountId})`); logRefreshSkipped(accountId, account.name, 'gemini', 'already_locked'); - + // 等待一段时间后返回,期望其他进程已完成刷新 await new Promise(resolve => setTimeout(resolve, 2000)); - + // 重新获取账户数据(可能已被其他进程刷新) const updatedAccount = await getAccount(accountId); if (updatedAccount && updatedAccount.accessToken) { @@ -627,17 +627,17 @@ async function refreshAccountToken(accountId) { token_type: 'Bearer' }; } - + throw new Error('Token refresh in progress by another process'); } - + // 记录开始刷新 logRefreshStart(accountId, account.name, 'gemini', 'manual_refresh'); logger.info(`🔄 Starting token refresh for Gemini account: ${account.name} (${accountId})`); - + // account.refreshToken 已经是解密后的值(从 getAccount 返回) const newTokens = await refreshAccessToken(account.refreshToken); - + // 更新账户信息 const updates = { accessToken: newTokens.access_token, @@ -648,9 +648,9 @@ async function refreshAccountToken(accountId) { status: 'active', // 刷新成功后,将状态更新为 active errorMessage: '' // 清空错误信息 }; - + await updateAccount(accountId, updates); - + // 记录刷新成功 logRefreshSuccess(accountId, account.name, 'gemini', { accessToken: newTokens.access_token, @@ -658,16 +658,16 @@ async function refreshAccountToken(accountId) { expiresAt: newTokens.expiry_date, scopes: newTokens.scope }); - + logger.info(`Refreshed token for Gemini account: ${accountId} - Access Token: ${maskToken(newTokens.access_token)}`); - + return newTokens; } catch (error) { // 记录刷新失败 logRefreshError(accountId, account ? account.name : 'Unknown', 'gemini', error); - + logger.error(`Failed to refresh token for account ${accountId}:`, error); - + // 标记账户为错误状态(只有在账户存在时) if (account) { try { @@ -679,7 +679,7 @@ async function refreshAccountToken(accountId) { logger.error('Failed to update account status after refresh error:', updateError); } } - + throw error; } finally { // 释放锁 @@ -705,10 +705,340 @@ async function setAccountRateLimited(accountId, isLimited = true) { rateLimitStatus: '', rateLimitedAt: '' }; - + await updateAccount(accountId, updates); } +// 获取配置的OAuth客户端 - 参考GeminiCliSimulator的getOauthClient方法 +async function getOauthClient(accessToken, refreshToken) { + const client = new OAuth2Client({ + clientId: OAUTH_CLIENT_ID, + clientSecret: OAUTH_CLIENT_SECRET, + }); + const creds = { + 'access_token': accessToken, + 'refresh_token': refreshToken, + 'scope': 'https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.profile openid https://www.googleapis.com/auth/userinfo.email', + 'token_type': 'Bearer', + 'expiry_date': 1754269905646 + }; + + // 设置凭据 + client.setCredentials(creds); + + // 验证凭据本地有效性 + const { token } = await client.getAccessToken(); + if (!token) { + return false; + } + + // 验证服务器端token状态(检查是否被撤销) + await client.getTokenInfo(token); + + logger.info('✅ OAuth客户端已创建'); + return client; +} + +// 调用 Google Code Assist API 的 loadCodeAssist 方法 +async function loadCodeAssist(client, projectId = null) { + const axios = require('axios'); + const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'; + const CODE_ASSIST_API_VERSION = 'v1internal'; + + const { token } = await client.getAccessToken(); + + // 创建ClientMetadata + const clientMetadata = { + ideType: 'IDE_UNSPECIFIED', + platform: 'PLATFORM_UNSPECIFIED', + pluginType: 'GEMINI', + duetProject: projectId, + }; + + const request = { + cloudaicompanionProject: projectId, + metadata: clientMetadata, + }; + + const response = await axios({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:loadCodeAssist`, + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + data: request, + timeout: 30000, + }); + + logger.info('📋 loadCodeAssist API调用成功'); + return response.data; +} + +// 获取onboard层级 - 参考GeminiCliSimulator的getOnboardTier方法 +function getOnboardTier(loadRes) { + // 用户层级枚举 + const UserTierId = { + LEGACY: 'LEGACY', + FREE: 'FREE', + PRO: 'PRO' + }; + + if (loadRes.currentTier) { + return loadRes.currentTier; + } + + for (const tier of loadRes.allowedTiers || []) { + if (tier.isDefault) { + return tier; + } + } + + return { + name: '', + description: '', + id: UserTierId.LEGACY, + userDefinedCloudaicompanionProject: true, + }; +} + +// 调用 Google Code Assist API 的 onboardUser 方法(包含轮询逻辑) +async function onboardUser(client, tierId, projectId, clientMetadata) { + const axios = require('axios'); + const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'; + const CODE_ASSIST_API_VERSION = 'v1internal'; + + const { token } = await client.getAccessToken(); + + const onboardReq = { + tierId: tierId, + cloudaicompanionProject: projectId, + metadata: clientMetadata, + }; + + logger.info('📋 开始onboardUser API调用', { tierId, projectId }); + + // 轮询onboardUser直到长运行操作完成 + let lroRes = await axios({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:onboardUser`, + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + data: onboardReq, + timeout: 30000, + }); + + let attempts = 0; + const maxAttempts = 12; // 最多等待1分钟(5秒 * 12次) + + while (!lroRes.data.done && attempts < maxAttempts) { + logger.info(`⏳ 等待onboardUser完成... (${attempts + 1}/${maxAttempts})`); + await new Promise(resolve => setTimeout(resolve, 5000)); + + lroRes = await axios({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:onboardUser`, + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + data: onboardReq, + timeout: 30000, + }); + + attempts++; + } + + if (!lroRes.data.done) { + throw new Error('onboardUser操作超时'); + } + + logger.info('✅ onboardUser API调用完成'); + return lroRes.data; +} + +// 完整的用户设置流程 - 参考setup.ts的逻辑 +async function setupUser(client, initialProjectId = null, clientMetadata = null) { + logger.info('🚀 setupUser 开始', { initialProjectId, hasClientMetadata: !!clientMetadata }); + + let projectId = initialProjectId || process.env.GOOGLE_CLOUD_PROJECT || null; + logger.info('📋 初始项目ID', { projectId, fromEnv: !!process.env.GOOGLE_CLOUD_PROJECT }); + + // 默认的ClientMetadata + if (!clientMetadata) { + clientMetadata = { + ideType: 'IDE_UNSPECIFIED', + platform: 'PLATFORM_UNSPECIFIED', + pluginType: 'GEMINI', + duetProject: projectId, + }; + logger.info('🔧 使用默认 ClientMetadata'); + } + + // 调用loadCodeAssist + logger.info('📞 调用 loadCodeAssist...'); + const loadRes = await loadCodeAssist(client, projectId); + logger.info('✅ loadCodeAssist 完成', { hasCloudaicompanionProject: !!loadRes.cloudaicompanionProject }); + + // 如果没有projectId,尝试从loadRes获取 + if (!projectId && loadRes.cloudaicompanionProject) { + projectId = loadRes.cloudaicompanionProject; + logger.info('📋 从 loadCodeAssist 获取项目ID', { projectId }); + } + + const tier = getOnboardTier(loadRes); + logger.info('🎯 获取用户层级', { tierId: tier.id, userDefinedProject: tier.userDefinedCloudaicompanionProject }); + + if (tier.userDefinedCloudaiCompanionProject && !projectId) { + throw new Error('此账号需要设置GOOGLE_CLOUD_PROJECT环境变量或提供projectId'); + } + + // 调用onboardUser + logger.info('📞 调用 onboardUser...', { tierId: tier.id, projectId }); + const lroRes = await onboardUser(client, tier.id, projectId, clientMetadata); + logger.info('✅ onboardUser 完成', { hasDone: !!lroRes.done, hasResponse: !!lroRes.response }); + + const result = { + projectId: lroRes.response?.cloudaicompanionProject?.id || projectId || '', + userTier: tier.id, + loadRes, + onboardRes: lroRes.response || {} + }; + + logger.info('🎯 setupUser 完成', { resultProjectId: result.projectId, userTier: result.userTier }); + return result; +} + +// 调用 Code Assist API 计算 token 数量 +async function countTokens(client, contents, model = 'gemini-2.0-flash-exp') { + const axios = require('axios'); + const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'; + const CODE_ASSIST_API_VERSION = 'v1internal'; + + const { token } = await client.getAccessToken(); + + // 按照 gemini-cli 的转换格式构造请求 + const request = { + request: { + model: `models/${model}`, + contents: contents + } + }; + + logger.info('📊 countTokens API调用开始', { model, contentsLength: contents.length }); + + const response = await axios({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:countTokens`, + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + data: request, + timeout: 30000, + }); + + logger.info('✅ countTokens API调用成功', { totalTokens: response.data.totalTokens }); + return response.data; +} + +// 调用 Code Assist API 生成内容(非流式) +async function generateContent(client, requestData, userPromptId, projectId = null, sessionId = null) { + const axios = require('axios'); + const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'; + const CODE_ASSIST_API_VERSION = 'v1internal'; + + const { token } = await client.getAccessToken(); + + // 按照 gemini-cli 的转换格式构造请求 + const request = { + model: requestData.model, + project: projectId, + user_prompt_id: userPromptId, + request: { + ...requestData.request, + session_id: sessionId + } + }; + + logger.info('🤖 generateContent API调用开始', { + model: requestData.model, + userPromptId, + projectId, + sessionId + }); + + const response = await axios({ + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:generateContent`, + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + data: request, + timeout: 60000, // 生成内容可能需要更长时间 + }); + + logger.info('✅ generateContent API调用成功'); + return response.data; +} + +// 调用 Code Assist API 生成内容(流式) +async function generateContentStream(client, requestData, userPromptId, projectId = null, sessionId = null, signal = null) { + const axios = require('axios'); + const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'; + const CODE_ASSIST_API_VERSION = 'v1internal'; + + const { token } = await client.getAccessToken(); + + // 按照 gemini-cli 的转换格式构造请求 + const request = { + model: requestData.model, + project: projectId, + user_prompt_id: userPromptId, + request: { + ...requestData.request, + session_id: sessionId + } + }; + + logger.info('🌊 streamGenerateContent API调用开始', { + model: requestData.model, + userPromptId, + projectId, + sessionId + }); + + const axiosConfig = { + url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:streamGenerateContent`, + method: 'POST', + params: { + alt: 'sse' + }, + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + data: request, + responseType: 'stream', + timeout: 60000, + }; + + // 如果提供了中止信号,添加到配置中 + if (signal) { + axiosConfig.signal = signal; + } + + const response = await axios(axiosConfig); + + logger.info('✅ streamGenerateContent API调用成功,开始流式传输'); + return response.data; // 返回流对象 +} + + + module.exports = { generateAuthUrl, pollAuthorizationStatus, @@ -724,6 +1054,14 @@ module.exports = { markAccountUsed, setAccountRateLimited, isTokenExpired, + getOauthClient, + loadCodeAssist, + getOnboardTier, + onboardUser, + setupUser, + countTokens, + generateContent, + generateContentStream, OAUTH_CLIENT_ID, OAUTH_SCOPES }; diff --git a/src/services/unifiedGeminiScheduler.js b/src/services/unifiedGeminiScheduler.js new file mode 100644 index 00000000..3860fed4 --- /dev/null +++ b/src/services/unifiedGeminiScheduler.js @@ -0,0 +1,376 @@ +const geminiAccountService = require('./geminiAccountService'); +const accountGroupService = require('./accountGroupService'); +const redis = require('../models/redis'); +const logger = require('../utils/logger'); + +class UnifiedGeminiScheduler { + constructor() { + this.SESSION_MAPPING_PREFIX = 'unified_gemini_session_mapping:'; + } + + // 🎯 统一调度Gemini账号 + async selectAccountForApiKey(apiKeyData, sessionHash = null, requestedModel = null) { + try { + // 如果API Key绑定了专属账户或分组,优先使用 + if (apiKeyData.geminiAccountId) { + // 检查是否是分组 + if (apiKeyData.geminiAccountId.startsWith('group:')) { + const groupId = apiKeyData.geminiAccountId.replace('group:', ''); + logger.info(`🎯 API key ${apiKeyData.name} is bound to group ${groupId}, selecting from group`); + return await this.selectAccountFromGroup(groupId, sessionHash, requestedModel, apiKeyData); + } + + // 普通专属账户 + const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId); + if (boundAccount && boundAccount.isActive === 'true' && boundAccount.status !== 'error') { + logger.info(`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId}) for API key ${apiKeyData.name}`); + return { + accountId: apiKeyData.geminiAccountId, + accountType: 'gemini' + }; + } else { + logger.warn(`⚠️ Bound Gemini account ${apiKeyData.geminiAccountId} is not available, falling back to pool`); + } + } + + // 如果有会话哈希,检查是否有已映射的账户 + if (sessionHash) { + const mappedAccount = await this._getSessionMapping(sessionHash); + if (mappedAccount) { + // 验证映射的账户是否仍然可用 + const isAvailable = await this._isAccountAvailable(mappedAccount.accountId, mappedAccount.accountType); + if (isAvailable) { + logger.info(`🎯 Using sticky session account: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`); + return mappedAccount; + } else { + logger.warn(`⚠️ Mapped account ${mappedAccount.accountId} is no longer available, selecting new account`); + await this._deleteSessionMapping(sessionHash); + } + } + } + + // 获取所有可用账户 + const availableAccounts = await this._getAllAvailableAccounts(apiKeyData, requestedModel); + + if (availableAccounts.length === 0) { + // 提供更详细的错误信息 + if (requestedModel) { + throw new Error(`No available Gemini accounts support the requested model: ${requestedModel}`); + } else { + throw new Error('No available Gemini accounts'); + } + } + + // 按优先级和最后使用时间排序 + const sortedAccounts = this._sortAccountsByPriority(availableAccounts); + + // 选择第一个账户 + const selectedAccount = sortedAccounts[0]; + + // 如果有会话哈希,建立新的映射 + if (sessionHash) { + await this._setSessionMapping(sessionHash, selectedAccount.accountId, selectedAccount.accountType); + logger.info(`🎯 Created new sticky session mapping: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`); + } + + logger.info(`🎯 Selected account: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority} for API key ${apiKeyData.name}`); + + return { + accountId: selectedAccount.accountId, + accountType: selectedAccount.accountType + }; + } catch (error) { + logger.error('❌ Failed to select account for API key:', error); + throw error; + } + } + + // 📋 获取所有可用账户 + async _getAllAvailableAccounts(apiKeyData, requestedModel = null) { + const availableAccounts = []; + + // 如果API Key绑定了专属账户,优先返回 + if (apiKeyData.geminiAccountId) { + const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId); + if (boundAccount && boundAccount.isActive === 'true' && boundAccount.status !== 'error') { + const isRateLimited = await this.isAccountRateLimited(boundAccount.id); + if (!isRateLimited) { + logger.info(`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId})`); + return [{ + ...boundAccount, + accountId: boundAccount.id, + accountType: 'gemini', + priority: parseInt(boundAccount.priority) || 50, + lastUsedAt: boundAccount.lastUsedAt || '0' + }]; + } + } else { + logger.warn(`⚠️ Bound Gemini account ${apiKeyData.geminiAccountId} is not available`); + } + } + + // 获取所有Gemini账户(共享池) + const geminiAccounts = await geminiAccountService.getAllAccounts(); + for (const account of geminiAccounts) { + if (account.isActive === 'true' && + account.status !== 'error' && + (account.accountType === 'shared' || !account.accountType) && // 兼容旧数据 + account.schedulable !== 'false') { // 检查是否可调度 + + // 检查token是否过期 + const isExpired = geminiAccountService.isTokenExpired(account); + if (isExpired && !account.refreshToken) { + logger.warn(`⚠️ Gemini account ${account.name} token expired and no refresh token available`); + continue; + } + + // 检查是否被限流 + const isRateLimited = await this.isAccountRateLimited(account.id); + if (!isRateLimited) { + availableAccounts.push({ + ...account, + accountId: account.id, + accountType: 'gemini', + priority: parseInt(account.priority) || 50, // 默认优先级50 + lastUsedAt: account.lastUsedAt || '0' + }); + } + } + } + + logger.info(`📊 Total available Gemini accounts: ${availableAccounts.length}`); + return availableAccounts; + } + + // 🔢 按优先级和最后使用时间排序账户 + _sortAccountsByPriority(accounts) { + return accounts.sort((a, b) => { + // 首先按优先级排序(数字越小优先级越高) + if (a.priority !== b.priority) { + return a.priority - b.priority; + } + + // 优先级相同时,按最后使用时间排序(最久未使用的优先) + const aLastUsed = new Date(a.lastUsedAt || 0).getTime(); + const bLastUsed = new Date(b.lastUsedAt || 0).getTime(); + return aLastUsed - bLastUsed; + }); + } + + // 🔍 检查账户是否可用 + async _isAccountAvailable(accountId, accountType) { + try { + if (accountType === 'gemini') { + const account = await geminiAccountService.getAccount(accountId); + if (!account || account.isActive !== 'true' || account.status === 'error') { + return false; + } + // 检查是否可调度 + if (account.schedulable === 'false') { + logger.info(`🚫 Gemini account ${accountId} is not schedulable`); + return false; + } + return !(await this.isAccountRateLimited(accountId)); + } + return false; + } catch (error) { + logger.warn(`⚠️ Failed to check account availability: ${accountId}`, error); + return false; + } + } + + // 🔗 获取会话映射 + async _getSessionMapping(sessionHash) { + const client = redis.getClientSafe(); + const mappingData = await client.get(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`); + + if (mappingData) { + try { + return JSON.parse(mappingData); + } catch (error) { + logger.warn('⚠️ Failed to parse session mapping:', error); + return null; + } + } + + return null; + } + + // 💾 设置会话映射 + async _setSessionMapping(sessionHash, accountId, accountType) { + const client = redis.getClientSafe(); + const mappingData = JSON.stringify({ accountId, accountType }); + + // 设置1小时过期 + await client.setex( + `${this.SESSION_MAPPING_PREFIX}${sessionHash}`, + 3600, + mappingData + ); + } + + // 🗑️ 删除会话映射 + async _deleteSessionMapping(sessionHash) { + const client = redis.getClientSafe(); + await client.del(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`); + } + + // 🚫 标记账户为限流状态 + async markAccountRateLimited(accountId, accountType, sessionHash = null) { + try { + if (accountType === 'gemini') { + await geminiAccountService.setAccountRateLimited(accountId, true); + } + + // 删除会话映射 + if (sessionHash) { + await this._deleteSessionMapping(sessionHash); + } + + return { success: true }; + } catch (error) { + logger.error(`❌ Failed to mark account as rate limited: ${accountId} (${accountType})`, error); + throw error; + } + } + + // ✅ 移除账户的限流状态 + async removeAccountRateLimit(accountId, accountType) { + try { + if (accountType === 'gemini') { + await geminiAccountService.setAccountRateLimited(accountId, false); + } + + return { success: true }; + } catch (error) { + logger.error(`❌ Failed to remove rate limit for account: ${accountId} (${accountType})`, error); + throw error; + } + } + + // 🔍 检查账户是否处于限流状态 + async isAccountRateLimited(accountId) { + try { + const account = await geminiAccountService.getAccount(accountId); + if (!account) return false; + + if (account.rateLimitStatus === 'limited' && account.rateLimitedAt) { + const limitedAt = new Date(account.rateLimitedAt).getTime(); + const now = Date.now(); + const limitDuration = 60 * 60 * 1000; // 1小时 + + return now < (limitedAt + limitDuration); + } + return false; + } catch (error) { + logger.error(`❌ Failed to check rate limit status: ${accountId}`, error); + return false; + } + } + + // 👥 从分组中选择账户 + async selectAccountFromGroup(groupId, sessionHash = null, requestedModel = null, apiKeyData = null) { + try { + // 获取分组信息 + const group = await accountGroupService.getGroup(groupId); + if (!group) { + throw new Error(`Group ${groupId} not found`); + } + + if (group.platform !== 'gemini') { + throw new Error(`Group ${group.name} is not a Gemini group`); + } + + logger.info(`👥 Selecting account from Gemini group: ${group.name}`); + + // 如果有会话哈希,检查是否有已映射的账户 + if (sessionHash) { + const mappedAccount = await this._getSessionMapping(sessionHash); + if (mappedAccount) { + // 验证映射的账户是否属于这个分组 + const memberIds = await accountGroupService.getGroupMembers(groupId); + if (memberIds.includes(mappedAccount.accountId)) { + const isAvailable = await this._isAccountAvailable(mappedAccount.accountId, mappedAccount.accountType); + if (isAvailable) { + logger.info(`🎯 Using sticky session account from group: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`); + return mappedAccount; + } + } + // 如果映射的账户不可用或不在分组中,删除映射 + await this._deleteSessionMapping(sessionHash); + } + } + + // 获取分组内的所有账户 + const memberIds = await accountGroupService.getGroupMembers(groupId); + if (memberIds.length === 0) { + throw new Error(`Group ${group.name} has no members`); + } + + const availableAccounts = []; + + // 获取所有成员账户的详细信息 + for (const memberId of memberIds) { + const account = await geminiAccountService.getAccount(memberId); + + if (!account) { + logger.warn(`⚠️ Gemini account ${memberId} not found in group ${group.name}`); + continue; + } + + // 检查账户是否可用 + if (account.isActive === 'true' && + account.status !== 'error' && + account.schedulable !== 'false') { + + // 检查token是否过期 + const isExpired = geminiAccountService.isTokenExpired(account); + if (isExpired && !account.refreshToken) { + logger.warn(`⚠️ Gemini account ${account.name} in group token expired and no refresh token available`); + continue; + } + + // 检查是否被限流 + const isRateLimited = await this.isAccountRateLimited(account.id); + if (!isRateLimited) { + availableAccounts.push({ + ...account, + accountId: account.id, + accountType: 'gemini', + priority: parseInt(account.priority) || 50, + lastUsedAt: account.lastUsedAt || '0' + }); + } + } + } + + if (availableAccounts.length === 0) { + throw new Error(`No available accounts in Gemini group ${group.name}`); + } + + // 使用现有的优先级排序逻辑 + const sortedAccounts = this._sortAccountsByPriority(availableAccounts); + + // 选择第一个账户 + const selectedAccount = sortedAccounts[0]; + + // 如果有会话哈希,建立新的映射 + if (sessionHash) { + await this._setSessionMapping(sessionHash, selectedAccount.accountId, selectedAccount.accountType); + logger.info(`🎯 Created new sticky session mapping in group: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`); + } + + logger.info(`🎯 Selected account from Gemini group ${group.name}: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority}`); + + return { + accountId: selectedAccount.accountId, + accountType: selectedAccount.accountType + }; + } catch (error) { + logger.error(`❌ Failed to select account from Gemini group ${groupId}:`, error); + throw error; + } + } +} + +module.exports = new UnifiedGeminiScheduler(); \ No newline at end of file