diff --git a/src/middleware/auth.js b/src/middleware/auth.js index 47bd4333..dc441781 100644 --- a/src/middleware/auth.js +++ b/src/middleware/auth.js @@ -11,6 +11,7 @@ const authenticateApiKey = async (req, res, next) => { try { // 安全提取API Key,支持多种格式 const apiKey = req.headers['x-api-key'] || + req.headers['x-goog-api-key'] || req.headers['authorization']?.replace(/^Bearer\s+/i, '') || req.headers['api-key']; diff --git a/src/routes/geminiRoutes.js b/src/routes/geminiRoutes.js index eefc48a9..ba10dd01 100644 --- a/src/routes/geminiRoutes.js +++ b/src/routes/geminiRoutes.js @@ -291,9 +291,9 @@ router.get('/key-info', authenticateApiKey, async (req, res) => { } }); -router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, async (req, res) => { +// 共用的 loadCodeAssist 处理函数 +async function handleLoadCodeAssist(req, res) { try { - const sessionHash = sessionHelper.generateSessionHash(req.body); // 使用统一调度选择账号(传递请求的模型) @@ -304,7 +304,8 @@ router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, async (req, res) const { metadata, cloudaicompanionProject } = req.body; - logger.info('LoadCodeAssist request', { + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.info(`LoadCodeAssist request (${version})`, { metadata: metadata || {}, cloudaicompanionProject: cloudaicompanionProject || null, apiKeyId: req.apiKey?.id || 'unknown' @@ -314,15 +315,17 @@ router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, async (req, res) const response = await geminiAccountService.loadCodeAssist(client, cloudaicompanionProject); res.json(response); } catch (error) { - logger.error('Error in loadCodeAssist endpoint', { error: error.message }); + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.error(`Error in loadCodeAssist endpoint (${version})`, { error: error.message }); res.status(500).json({ error: 'Internal server error', message: error.message }); } -}); +} -router.post('/v1internal\\:onboardUser', authenticateApiKey, async (req, res) => { +// 共用的 onboardUser 处理函数 +async function handleOnboardUser(req, res) { try { const { tierId, cloudaicompanionProject, metadata } = req.body; const sessionHash = sessionHelper.generateSessionHash(req.body); @@ -332,7 +335,8 @@ router.post('/v1internal\\:onboardUser', authenticateApiKey, async (req, res) => const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel); const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId); - logger.info('OnboardUser request', { + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.info(`OnboardUser request (${version})`, { tierId: tierId || 'not provided', cloudaicompanionProject: cloudaicompanionProject || null, metadata: metadata || {}, @@ -351,15 +355,17 @@ router.post('/v1internal\\:onboardUser', authenticateApiKey, async (req, res) => res.json(response); } } catch (error) { - logger.error('Error in onboardUser endpoint', { error: error.message }); + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.error(`Error in onboardUser endpoint (${version})`, { error: error.message }); res.status(500).json({ error: 'Internal server error', message: error.message }); } -}); +} -router.post('/v1internal\\:countTokens', authenticateApiKey, async (req, res) => { +// 共用的 countTokens 处理函数 +async function handleCountTokens(req, res) { try { // 处理请求体结构,支持直接 contents 或 request.contents const requestData = req.body.request || req.body; @@ -380,7 +386,8 @@ router.post('/v1internal\\:countTokens', authenticateApiKey, async (req, res) => const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model); const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId); - logger.info('CountTokens request', { + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.info(`CountTokens request (${version})`, { model: model, contentsLength: contents.length, apiKeyId: req.apiKey?.id || 'unknown' @@ -391,7 +398,8 @@ router.post('/v1internal\\:countTokens', authenticateApiKey, async (req, res) => res.json(response); } catch (error) { - logger.error('Error in countTokens endpoint', { error: error.message }); + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.error(`Error in countTokens endpoint (${version})`, { error: error.message }); res.status(500).json({ error: { message: error.message || 'Internal server error', @@ -399,15 +407,39 @@ router.post('/v1internal\\:countTokens', authenticateApiKey, async (req, res) => } }); } -}); +} -router.post('/v1internal\\:generateContent', authenticateApiKey, async (req, res) => { +// 共用的 generateContent 处理函数 +async function handleGenerateContent(req, res) { try { const { model, project, user_prompt_id, request: requestData } = req.body; const sessionHash = sessionHelper.generateSessionHash(req.body); + + // 处理不同格式的请求 + let actualRequestData = requestData; + if (!requestData) { + if (req.body.messages) { + // 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象 + actualRequestData = { + contents: req.body.messages.map(msg => ({ + role: msg.role === 'assistant' ? 'model' : msg.role, + parts: [{ text: msg.content }] + })), + generationConfig: { + temperature: req.body.temperature !== undefined ? req.body.temperature : 0.7, + maxOutputTokens: req.body.max_tokens !== undefined ? req.body.max_tokens : 4096, + topP: req.body.top_p !== undefined ? req.body.top_p : 0.95, + topK: req.body.top_k !== undefined ? req.body.top_k : 40 + } + }; + } else if (req.body.contents) { + // 直接的 Gemini 格式请求(没有 request 包装) + actualRequestData = req.body; + } + } // 验证必需参数 - if (!requestData || !requestData.contents) { + if (!actualRequestData || !actualRequestData.contents) { return res.status(400).json({ error: { message: 'Request contents are required', @@ -421,7 +453,8 @@ router.post('/v1internal\\:generateContent', authenticateApiKey, async (req, res const account = await geminiAccountService.getAccount(accountId); const { accessToken, refreshToken } = account; - logger.info('GenerateContent request', { + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.info(`GenerateContent request (${version})`, { model: model, userPromptId: user_prompt_id, projectId: project || account.projectId, @@ -431,7 +464,7 @@ router.post('/v1internal\\:generateContent', authenticateApiKey, async (req, res const client = await geminiAccountService.getOauthClient(accessToken, refreshToken); const response = await geminiAccountService.generateContent( client, - { model, request: requestData }, + { model, request: actualRequestData }, user_prompt_id, project || account.projectId, req.apiKey?.id // 使用 API Key ID 作为 session ID @@ -439,7 +472,9 @@ router.post('/v1internal\\:generateContent', authenticateApiKey, async (req, res res.json(response); } catch (error) { - logger.error('Error in generateContent endpoint', { error: error.message }); + console.log(321, error.response); + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.error(`Error in generateContent endpoint (${version})`, { error: error.message }); res.status(500).json({ error: { message: error.message || 'Internal server error', @@ -447,17 +482,41 @@ router.post('/v1internal\\:generateContent', authenticateApiKey, async (req, res } }); } -}); +} -router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, async (req, res) => { +// 共用的 streamGenerateContent 处理函数 +async function handleStreamGenerateContent(req, res) { let abortController = null; try { const { model, project, user_prompt_id, request: requestData } = req.body; const sessionHash = sessionHelper.generateSessionHash(req.body); + // 处理不同格式的请求 + let actualRequestData = requestData; + if (!requestData) { + if (req.body.messages) { + // 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象 + actualRequestData = { + contents: req.body.messages.map(msg => ({ + role: msg.role === 'assistant' ? 'model' : msg.role, + parts: [{ text: msg.content }] + })), + generationConfig: { + temperature: req.body.temperature !== undefined ? req.body.temperature : 0.7, + maxOutputTokens: req.body.max_tokens !== undefined ? req.body.max_tokens : 4096, + topP: req.body.top_p !== undefined ? req.body.top_p : 0.95, + topK: req.body.top_k !== undefined ? req.body.top_k : 40 + } + }; + } else if (req.body.contents) { + // 直接的 Gemini 格式请求(没有 request 包装) + actualRequestData = req.body; + } + } + // 验证必需参数 - if (!requestData || !requestData.contents) { + if (!actualRequestData || !actualRequestData.contents) { return res.status(400).json({ error: { message: 'Request contents are required', @@ -471,7 +530,8 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, async (re const account = await geminiAccountService.getAccount(accountId); const { accessToken, refreshToken } = account; - logger.info('StreamGenerateContent request', { + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.info(`StreamGenerateContent request (${version})`, { model: model, userPromptId: user_prompt_id, projectId: project || account.projectId, @@ -492,7 +552,7 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, async (re const client = await geminiAccountService.getOauthClient(accessToken, refreshToken); const streamResponse = await geminiAccountService.generateContentStream( client, - { model, request: requestData }, + { model, request: actualRequestData }, user_prompt_id, project || account.projectId, req.apiKey?.id, // 使用 API Key ID 作为 session ID @@ -528,7 +588,8 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, async (re }); } catch (error) { - logger.error('Error in streamGenerateContent endpoint', { error: error.message }); + const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'; + logger.error(`Error in streamGenerateContent endpoint (${version})`, { error: error.message }); if (!res.headersSent) { res.status(500).json({ @@ -544,6 +605,21 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, async (re abortController = null; } } -}); +} + +// 注册所有路由端点 +// v1internal 版本的端点 +router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist); +router.post('/v1internal\\:onboardUser', authenticateApiKey, handleOnboardUser); +router.post('/v1internal\\:countTokens', authenticateApiKey, handleCountTokens); +router.post('/v1internal\\:generateContent', authenticateApiKey, handleGenerateContent); +router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent); + +// v1beta 版本的端点 - 支持动态模型名称 +router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist); +router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, handleOnboardUser); +router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, handleCountTokens); +router.post('/v1beta/models/:modelName\\:generateContent', authenticateApiKey, handleGenerateContent); +router.post('/v1beta/models/:modelName\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent); module.exports = router; \ No newline at end of file