diff --git a/src/routes/geminiRoutes.js b/src/routes/geminiRoutes.js index 7aebc7ae..df447fb7 100644 --- a/src/routes/geminiRoutes.js +++ b/src/routes/geminiRoutes.js @@ -29,6 +29,26 @@ function checkPermissions(apiKeyData, requiredPermission = 'gemini') { return permissions === 'all' || permissions === requiredPermission } +// 确保请求具有 Gemini 访问权限 +function ensureGeminiPermission(req, res) { + const apiKeyData = req.apiKey || {} + if (checkPermissions(apiKeyData, 'gemini')) { + return true + } + + logger.security( + `🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 Gemini 权限,拒绝访问 ${req.originalUrl}` + ) + + res.status(403).json({ + error: { + message: 'This API key does not have permission to access Gemini', + type: 'permission_denied' + } + }) + return false +} + // Gemini 消息处理端点 router.post('/messages', authenticateApiKey, async (req, res) => { const startTime = Date.now() @@ -309,6 +329,10 @@ router.get('/key-info', authenticateApiKey, async (req, res) => { // 共用的 loadCodeAssist 处理函数 async function handleLoadCodeAssist(req, res) { try { + if (!ensureGeminiPermission(req, res)) { + return undefined + } + const sessionHash = sessionHelper.generateSessionHash(req.body) // 从路径参数或请求体中获取模型名 @@ -388,6 +412,10 @@ async function handleLoadCodeAssist(req, res) { // 共用的 onboardUser 处理函数 async function handleOnboardUser(req, res) { try { + if (!ensureGeminiPermission(req, res)) { + return undefined + } + // 提取请求参数 const { tierId, cloudaicompanionProject, metadata } = req.body const sessionHash = sessionHelper.generateSessionHash(req.body) @@ -475,6 +503,10 @@ async function handleOnboardUser(req, res) { // 共用的 countTokens 处理函数 async function handleCountTokens(req, res) { try { + if (!ensureGeminiPermission(req, res)) { + return undefined + } + // 处理请求体结构,支持直接 contents 或 request.contents const requestData = req.body.request || req.body const { contents } = requestData @@ -538,6 +570,10 @@ async function handleCountTokens(req, res) { // 共用的 generateContent 处理函数 async function handleGenerateContent(req, res) { try { + if (!ensureGeminiPermission(req, res)) { + return undefined + } + const { project, user_prompt_id, request: requestData } = req.body // 从路径参数或请求体中获取模型名 const model = req.body.model || req.params.modelName || 'gemini-2.5-flash' @@ -676,6 +712,10 @@ async function handleStreamGenerateContent(req, res) { let abortController = null try { + if (!ensureGeminiPermission(req, res)) { + return undefined + } + const { project, user_prompt_id, request: requestData } = req.body // 从路径参数或请求体中获取模型名 const model = req.body.model || req.params.modelName || 'gemini-2.5-flash' diff --git a/src/routes/openaiRoutes.js b/src/routes/openaiRoutes.js index eff686f8..684c8327 100644 --- a/src/routes/openaiRoutes.js +++ b/src/routes/openaiRoutes.js @@ -17,6 +17,12 @@ function createProxyAgent(proxy) { return ProxyHelper.createProxyAgent(proxy) } +// 检查 API Key 是否具备 OpenAI 权限 +function checkOpenAIPermissions(apiKeyData) { + const permissions = apiKeyData?.permissions || 'all' + return permissions === 'all' || permissions === 'openai' +} + function normalizeHeaders(headers = {}) { if (!headers || typeof headers !== 'object') { return {} @@ -190,6 +196,19 @@ const handleResponses = async (req, res) => { // 从中间件获取 API Key 数据 const apiKeyData = req.apiKey || {} + if (!checkOpenAIPermissions(apiKeyData)) { + logger.security( + `🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 OpenAI 权限,拒绝访问 ${req.originalUrl}` + ) + return res.status(403).json({ + error: { + message: 'This API key does not have permission to access OpenAI', + type: 'permission_denied', + code: 'permission_denied' + } + }) + } + // 从请求头或请求体中提取会话 ID const sessionId = req.headers['session_id'] || diff --git a/src/routes/standardGeminiRoutes.js b/src/routes/standardGeminiRoutes.js index 8d049574..dc981ebe 100644 --- a/src/routes/standardGeminiRoutes.js +++ b/src/routes/standardGeminiRoutes.js @@ -10,6 +10,40 @@ const sessionHelper = require('../utils/sessionHelper') // 导入 geminiRoutes 中导出的处理函数 const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('./geminiRoutes') +// 检查 API Key 是否具备 Gemini 权限 +function hasGeminiPermission(apiKeyData, requiredPermission = 'gemini') { + const permissions = apiKeyData?.permissions || 'all' + return permissions === 'all' || permissions === requiredPermission +} + +// 确保请求拥有 Gemini 权限 +function ensureGeminiPermission(req, res) { + const apiKeyData = req.apiKey || {} + if (hasGeminiPermission(apiKeyData, 'gemini')) { + return true + } + + logger.security( + `🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 Gemini 权限,拒绝访问 ${req.originalUrl}` + ) + + res.status(403).json({ + error: { + message: 'This API key does not have permission to access Gemini', + type: 'permission_denied' + } + }) + return false +} + +// 供路由中间件复用的权限检查 +function ensureGeminiPermissionMiddleware(req, res, next) { + if (ensureGeminiPermission(req, res)) { + return next() + } + return undefined +} + // 标准 Gemini API 路由处理器 // 这些路由将挂载在 /gemini 路径下,处理标准 Gemini API 格式的请求 // 标准格式: /gemini/v1beta/models/{model}:generateContent @@ -17,6 +51,10 @@ const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require(' // 专门处理标准 Gemini API 格式的 generateContent async function handleStandardGenerateContent(req, res) { try { + if (!ensureGeminiPermission(req, res)) { + return undefined + } + // 从路径参数中获取模型名 const model = req.params.modelName || 'gemini-2.0-flash-exp' const sessionHash = sessionHelper.generateSessionHash(req.body) @@ -225,6 +263,10 @@ async function handleStandardStreamGenerateContent(req, res) { let abortController = null try { + if (!ensureGeminiPermission(req, res)) { + return undefined + } + // 从路径参数中获取模型名 const model = req.params.modelName || 'gemini-2.0-flash-exp' const sessionHash = sessionHelper.generateSessionHash(req.body) @@ -535,31 +577,48 @@ async function handleStandardStreamGenerateContent(req, res) { } // v1beta 版本的标准路由 - 支持动态模型名称 -router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, (req, res, next) => { +router.post( + '/v1beta/models/:modelName\\:loadCodeAssist', + authenticateApiKey, + ensureGeminiPermissionMiddleware, + (req, res, next) => { logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`) handleLoadCodeAssist(req, res, next) -}) + } +) -router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, (req, res, next) => { +router.post( + '/v1beta/models/:modelName\\:onboardUser', + authenticateApiKey, + ensureGeminiPermissionMiddleware, + (req, res, next) => { logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`) handleOnboardUser(req, res, next) -}) + } +) -router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => { +router.post( + '/v1beta/models/:modelName\\:countTokens', + authenticateApiKey, + ensureGeminiPermissionMiddleware, + (req, res, next) => { logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`) handleCountTokens(req, res, next) -}) + } +) // 使用专门的处理函数处理标准 Gemini API 格式 router.post( '/v1beta/models/:modelName\\:generateContent', authenticateApiKey, + ensureGeminiPermissionMiddleware, handleStandardGenerateContent ) router.post( '/v1beta/models/:modelName\\:streamGenerateContent', authenticateApiKey, + ensureGeminiPermissionMiddleware, handleStandardStreamGenerateContent ) @@ -567,45 +626,52 @@ router.post( router.post( '/v1/models/:modelName\\:generateContent', authenticateApiKey, + ensureGeminiPermissionMiddleware, handleStandardGenerateContent ) router.post( '/v1/models/:modelName\\:streamGenerateContent', authenticateApiKey, + ensureGeminiPermissionMiddleware, handleStandardStreamGenerateContent ) -router.post('/v1/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => { +router.post( + '/v1/models/:modelName\\:countTokens', + authenticateApiKey, + ensureGeminiPermissionMiddleware, + (req, res, next) => { logger.info(`Standard Gemini API request (v1): ${req.method} ${req.originalUrl}`) handleCountTokens(req, res, next) -}) + } +) // v1internal 版本的标准路由(这些使用原有的处理函数,因为格式不同) -router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, (req, res, next) => { +router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => { logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) handleLoadCodeAssist(req, res, next) }) -router.post('/v1internal\\:onboardUser', authenticateApiKey, (req, res, next) => { +router.post('/v1internal\\:onboardUser', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => { logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) handleOnboardUser(req, res, next) }) -router.post('/v1internal\\:countTokens', authenticateApiKey, (req, res, next) => { +router.post('/v1internal\\:countTokens', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => { logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) handleCountTokens(req, res, next) }) // v1internal 使用不同的处理逻辑,因为它们不包含模型在 URL 中 -router.post('/v1internal\\:generateContent', authenticateApiKey, (req, res, next) => { +router.post('/v1internal\\:generateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => { logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) // v1internal 格式不同,使用原有的处理函数 const { handleGenerateContent } = require('./geminiRoutes') handleGenerateContent(req, res, next) }) -router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res, next) => { +router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => { logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) // v1internal 格式不同,使用原有的处理函数 const { handleStreamGenerateContent } = require('./geminiRoutes') @@ -613,32 +679,37 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res }) // 添加标准 Gemini API 的模型列表端点 -router.get('/v1beta/models', authenticateApiKey, async (req, res) => { - try { - logger.info('Standard Gemini API models request') - // 直接调用 geminiRoutes 中的模型处理逻辑 - const geminiRoutes = require('./geminiRoutes') - const modelHandler = geminiRoutes.stack.find( - (layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get - ) - if (modelHandler && modelHandler.route.stack[1]) { - // 调用处理函数(跳过第一个 authenticateApiKey 中间件) - modelHandler.route.stack[1].handle(req, res) - } else { - res.status(500).json({ error: 'Models handler not found' }) - } - } catch (error) { - logger.error('Error in standard models endpoint:', error) - res.status(500).json({ - error: { - message: 'Failed to retrieve models', - type: 'api_error' +router.get( + '/v1beta/models', + authenticateApiKey, + ensureGeminiPermissionMiddleware, + async (req, res) => { + try { + logger.info('Standard Gemini API models request') + // 直接调用 geminiRoutes 中的模型处理逻辑 + const geminiRoutes = require('./geminiRoutes') + const modelHandler = geminiRoutes.stack.find( + (layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get + ) + if (modelHandler && modelHandler.route.stack[1]) { + // 调用处理函数(跳过第一个 authenticateApiKey 中间件) + modelHandler.route.stack[1].handle(req, res) + } else { + res.status(500).json({ error: 'Models handler not found' }) } - }) + } catch (error) { + logger.error('Error in standard models endpoint:', error) + res.status(500).json({ + error: { + message: 'Failed to retrieve models', + type: 'api_error' + } + }) + } } -}) +) -router.get('/v1/models', authenticateApiKey, async (req, res) => { +router.get('/v1/models', authenticateApiKey, ensureGeminiPermissionMiddleware, async (req, res) => { try { logger.info('Standard Gemini API models request (v1)') // 直接调用 geminiRoutes 中的模型处理逻辑 @@ -663,7 +734,7 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => { }) // 添加模型详情端点 -router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => { +router.get('/v1beta/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => { const { modelName } = req.params logger.info(`Standard Gemini API model details request: ${modelName}`) @@ -681,7 +752,7 @@ router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => { }) }) -router.get('/v1/models/:modelName', authenticateApiKey, (req, res) => { +router.get('/v1/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => { const { modelName } = req.params logger.info(`Standard Gemini API model details request (v1): ${modelName}`)