fix: apikey的服务权限问题修复

This commit is contained in:
shaw
2025-09-25 22:51:39 +08:00
parent 66bb3419b7
commit 25d1c3f74e
3 changed files with 168 additions and 38 deletions

View File

@@ -29,6 +29,26 @@ function checkPermissions(apiKeyData, requiredPermission = 'gemini') {
return permissions === 'all' || permissions === requiredPermission return permissions === 'all' || permissions === requiredPermission
} }
// 确保请求具有 Gemini 访问权限
function ensureGeminiPermission(req, res) {
const apiKeyData = req.apiKey || {}
if (checkPermissions(apiKeyData, 'gemini')) {
return true
}
logger.security(
`🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 Gemini 权限,拒绝访问 ${req.originalUrl}`
)
res.status(403).json({
error: {
message: 'This API key does not have permission to access Gemini',
type: 'permission_denied'
}
})
return false
}
// Gemini 消息处理端点 // Gemini 消息处理端点
router.post('/messages', authenticateApiKey, async (req, res) => { router.post('/messages', authenticateApiKey, async (req, res) => {
const startTime = Date.now() const startTime = Date.now()
@@ -309,6 +329,10 @@ router.get('/key-info', authenticateApiKey, async (req, res) => {
// 共用的 loadCodeAssist 处理函数 // 共用的 loadCodeAssist 处理函数
async function handleLoadCodeAssist(req, res) { async function handleLoadCodeAssist(req, res) {
try { try {
if (!ensureGeminiPermission(req, res)) {
return undefined
}
const sessionHash = sessionHelper.generateSessionHash(req.body) const sessionHash = sessionHelper.generateSessionHash(req.body)
// 从路径参数或请求体中获取模型名 // 从路径参数或请求体中获取模型名
@@ -388,6 +412,10 @@ async function handleLoadCodeAssist(req, res) {
// 共用的 onboardUser 处理函数 // 共用的 onboardUser 处理函数
async function handleOnboardUser(req, res) { async function handleOnboardUser(req, res) {
try { try {
if (!ensureGeminiPermission(req, res)) {
return undefined
}
// 提取请求参数 // 提取请求参数
const { tierId, cloudaicompanionProject, metadata } = req.body const { tierId, cloudaicompanionProject, metadata } = req.body
const sessionHash = sessionHelper.generateSessionHash(req.body) const sessionHash = sessionHelper.generateSessionHash(req.body)
@@ -475,6 +503,10 @@ async function handleOnboardUser(req, res) {
// 共用的 countTokens 处理函数 // 共用的 countTokens 处理函数
async function handleCountTokens(req, res) { async function handleCountTokens(req, res) {
try { try {
if (!ensureGeminiPermission(req, res)) {
return undefined
}
// 处理请求体结构,支持直接 contents 或 request.contents // 处理请求体结构,支持直接 contents 或 request.contents
const requestData = req.body.request || req.body const requestData = req.body.request || req.body
const { contents } = requestData const { contents } = requestData
@@ -538,6 +570,10 @@ async function handleCountTokens(req, res) {
// 共用的 generateContent 处理函数 // 共用的 generateContent 处理函数
async function handleGenerateContent(req, res) { async function handleGenerateContent(req, res) {
try { try {
if (!ensureGeminiPermission(req, res)) {
return undefined
}
const { project, user_prompt_id, request: requestData } = req.body const { project, user_prompt_id, request: requestData } = req.body
// 从路径参数或请求体中获取模型名 // 从路径参数或请求体中获取模型名
const model = req.body.model || req.params.modelName || 'gemini-2.5-flash' const model = req.body.model || req.params.modelName || 'gemini-2.5-flash'
@@ -676,6 +712,10 @@ async function handleStreamGenerateContent(req, res) {
let abortController = null let abortController = null
try { try {
if (!ensureGeminiPermission(req, res)) {
return undefined
}
const { project, user_prompt_id, request: requestData } = req.body const { project, user_prompt_id, request: requestData } = req.body
// 从路径参数或请求体中获取模型名 // 从路径参数或请求体中获取模型名
const model = req.body.model || req.params.modelName || 'gemini-2.5-flash' const model = req.body.model || req.params.modelName || 'gemini-2.5-flash'

View File

@@ -17,6 +17,12 @@ function createProxyAgent(proxy) {
return ProxyHelper.createProxyAgent(proxy) return ProxyHelper.createProxyAgent(proxy)
} }
// 检查 API Key 是否具备 OpenAI 权限
function checkOpenAIPermissions(apiKeyData) {
const permissions = apiKeyData?.permissions || 'all'
return permissions === 'all' || permissions === 'openai'
}
function normalizeHeaders(headers = {}) { function normalizeHeaders(headers = {}) {
if (!headers || typeof headers !== 'object') { if (!headers || typeof headers !== 'object') {
return {} return {}
@@ -190,6 +196,19 @@ const handleResponses = async (req, res) => {
// 从中间件获取 API Key 数据 // 从中间件获取 API Key 数据
const apiKeyData = req.apiKey || {} const apiKeyData = req.apiKey || {}
if (!checkOpenAIPermissions(apiKeyData)) {
logger.security(
`🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 OpenAI 权限,拒绝访问 ${req.originalUrl}`
)
return res.status(403).json({
error: {
message: 'This API key does not have permission to access OpenAI',
type: 'permission_denied',
code: 'permission_denied'
}
})
}
// 从请求头或请求体中提取会话 ID // 从请求头或请求体中提取会话 ID
const sessionId = const sessionId =
req.headers['session_id'] || req.headers['session_id'] ||

View File

@@ -10,6 +10,40 @@ const sessionHelper = require('../utils/sessionHelper')
// 导入 geminiRoutes 中导出的处理函数 // 导入 geminiRoutes 中导出的处理函数
const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('./geminiRoutes') const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('./geminiRoutes')
// 检查 API Key 是否具备 Gemini 权限
function hasGeminiPermission(apiKeyData, requiredPermission = 'gemini') {
const permissions = apiKeyData?.permissions || 'all'
return permissions === 'all' || permissions === requiredPermission
}
// 确保请求拥有 Gemini 权限
function ensureGeminiPermission(req, res) {
const apiKeyData = req.apiKey || {}
if (hasGeminiPermission(apiKeyData, 'gemini')) {
return true
}
logger.security(
`🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 Gemini 权限,拒绝访问 ${req.originalUrl}`
)
res.status(403).json({
error: {
message: 'This API key does not have permission to access Gemini',
type: 'permission_denied'
}
})
return false
}
// 供路由中间件复用的权限检查
function ensureGeminiPermissionMiddleware(req, res, next) {
if (ensureGeminiPermission(req, res)) {
return next()
}
return undefined
}
// 标准 Gemini API 路由处理器 // 标准 Gemini API 路由处理器
// 这些路由将挂载在 /gemini 路径下,处理标准 Gemini API 格式的请求 // 这些路由将挂载在 /gemini 路径下,处理标准 Gemini API 格式的请求
// 标准格式: /gemini/v1beta/models/{model}:generateContent // 标准格式: /gemini/v1beta/models/{model}:generateContent
@@ -17,6 +51,10 @@ const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('
// 专门处理标准 Gemini API 格式的 generateContent // 专门处理标准 Gemini API 格式的 generateContent
async function handleStandardGenerateContent(req, res) { async function handleStandardGenerateContent(req, res) {
try { try {
if (!ensureGeminiPermission(req, res)) {
return undefined
}
// 从路径参数中获取模型名 // 从路径参数中获取模型名
const model = req.params.modelName || 'gemini-2.0-flash-exp' const model = req.params.modelName || 'gemini-2.0-flash-exp'
const sessionHash = sessionHelper.generateSessionHash(req.body) const sessionHash = sessionHelper.generateSessionHash(req.body)
@@ -225,6 +263,10 @@ async function handleStandardStreamGenerateContent(req, res) {
let abortController = null let abortController = null
try { try {
if (!ensureGeminiPermission(req, res)) {
return undefined
}
// 从路径参数中获取模型名 // 从路径参数中获取模型名
const model = req.params.modelName || 'gemini-2.0-flash-exp' const model = req.params.modelName || 'gemini-2.0-flash-exp'
const sessionHash = sessionHelper.generateSessionHash(req.body) const sessionHash = sessionHelper.generateSessionHash(req.body)
@@ -535,31 +577,48 @@ async function handleStandardStreamGenerateContent(req, res) {
} }
// v1beta 版本的标准路由 - 支持动态模型名称 // v1beta 版本的标准路由 - 支持动态模型名称
router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, (req, res, next) => { router.post(
'/v1beta/models/:modelName\\:loadCodeAssist',
authenticateApiKey,
ensureGeminiPermissionMiddleware,
(req, res, next) => {
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
handleLoadCodeAssist(req, res, next) handleLoadCodeAssist(req, res, next)
}) }
)
router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, (req, res, next) => { router.post(
'/v1beta/models/:modelName\\:onboardUser',
authenticateApiKey,
ensureGeminiPermissionMiddleware,
(req, res, next) => {
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
handleOnboardUser(req, res, next) handleOnboardUser(req, res, next)
}) }
)
router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => { router.post(
'/v1beta/models/:modelName\\:countTokens',
authenticateApiKey,
ensureGeminiPermissionMiddleware,
(req, res, next) => {
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
handleCountTokens(req, res, next) handleCountTokens(req, res, next)
}) }
)
// 使用专门的处理函数处理标准 Gemini API 格式 // 使用专门的处理函数处理标准 Gemini API 格式
router.post( router.post(
'/v1beta/models/:modelName\\:generateContent', '/v1beta/models/:modelName\\:generateContent',
authenticateApiKey, authenticateApiKey,
ensureGeminiPermissionMiddleware,
handleStandardGenerateContent handleStandardGenerateContent
) )
router.post( router.post(
'/v1beta/models/:modelName\\:streamGenerateContent', '/v1beta/models/:modelName\\:streamGenerateContent',
authenticateApiKey, authenticateApiKey,
ensureGeminiPermissionMiddleware,
handleStandardStreamGenerateContent handleStandardStreamGenerateContent
) )
@@ -567,45 +626,52 @@ router.post(
router.post( router.post(
'/v1/models/:modelName\\:generateContent', '/v1/models/:modelName\\:generateContent',
authenticateApiKey, authenticateApiKey,
ensureGeminiPermissionMiddleware,
handleStandardGenerateContent handleStandardGenerateContent
) )
router.post( router.post(
'/v1/models/:modelName\\:streamGenerateContent', '/v1/models/:modelName\\:streamGenerateContent',
authenticateApiKey, authenticateApiKey,
ensureGeminiPermissionMiddleware,
handleStandardStreamGenerateContent handleStandardStreamGenerateContent
) )
router.post('/v1/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => { router.post(
'/v1/models/:modelName\\:countTokens',
authenticateApiKey,
ensureGeminiPermissionMiddleware,
(req, res, next) => {
logger.info(`Standard Gemini API request (v1): ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request (v1): ${req.method} ${req.originalUrl}`)
handleCountTokens(req, res, next) handleCountTokens(req, res, next)
}) }
)
// v1internal 版本的标准路由(这些使用原有的处理函数,因为格式不同) // v1internal 版本的标准路由(这些使用原有的处理函数,因为格式不同)
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, (req, res, next) => { router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
handleLoadCodeAssist(req, res, next) handleLoadCodeAssist(req, res, next)
}) })
router.post('/v1internal\\:onboardUser', authenticateApiKey, (req, res, next) => { router.post('/v1internal\\:onboardUser', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
handleOnboardUser(req, res, next) handleOnboardUser(req, res, next)
}) })
router.post('/v1internal\\:countTokens', authenticateApiKey, (req, res, next) => { router.post('/v1internal\\:countTokens', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
handleCountTokens(req, res, next) handleCountTokens(req, res, next)
}) })
// v1internal 使用不同的处理逻辑,因为它们不包含模型在 URL 中 // v1internal 使用不同的处理逻辑,因为它们不包含模型在 URL 中
router.post('/v1internal\\:generateContent', authenticateApiKey, (req, res, next) => { router.post('/v1internal\\:generateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
// v1internal 格式不同,使用原有的处理函数 // v1internal 格式不同,使用原有的处理函数
const { handleGenerateContent } = require('./geminiRoutes') const { handleGenerateContent } = require('./geminiRoutes')
handleGenerateContent(req, res, next) handleGenerateContent(req, res, next)
}) })
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res, next) => { router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`) logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
// v1internal 格式不同,使用原有的处理函数 // v1internal 格式不同,使用原有的处理函数
const { handleStreamGenerateContent } = require('./geminiRoutes') const { handleStreamGenerateContent } = require('./geminiRoutes')
@@ -613,32 +679,37 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res
}) })
// 添加标准 Gemini API 的模型列表端点 // 添加标准 Gemini API 的模型列表端点
router.get('/v1beta/models', authenticateApiKey, async (req, res) => { router.get(
try { '/v1beta/models',
logger.info('Standard Gemini API models request') authenticateApiKey,
// 直接调用 geminiRoutes 中的模型处理逻辑 ensureGeminiPermissionMiddleware,
const geminiRoutes = require('./geminiRoutes') async (req, res) => {
const modelHandler = geminiRoutes.stack.find( try {
(layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get logger.info('Standard Gemini API models request')
) // 直接调用 geminiRoutes 中的模型处理逻辑
if (modelHandler && modelHandler.route.stack[1]) { const geminiRoutes = require('./geminiRoutes')
// 调用处理函数(跳过第一个 authenticateApiKey 中间件) const modelHandler = geminiRoutes.stack.find(
modelHandler.route.stack[1].handle(req, res) (layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get
} else { )
res.status(500).json({ error: 'Models handler not found' }) if (modelHandler && modelHandler.route.stack[1]) {
} // 调用处理函数(跳过第一个 authenticateApiKey 中间件)
} catch (error) { modelHandler.route.stack[1].handle(req, res)
logger.error('Error in standard models endpoint:', error) } else {
res.status(500).json({ res.status(500).json({ error: 'Models handler not found' })
error: {
message: 'Failed to retrieve models',
type: 'api_error'
} }
}) } catch (error) {
logger.error('Error in standard models endpoint:', error)
res.status(500).json({
error: {
message: 'Failed to retrieve models',
type: 'api_error'
}
})
}
} }
}) )
router.get('/v1/models', authenticateApiKey, async (req, res) => { router.get('/v1/models', authenticateApiKey, ensureGeminiPermissionMiddleware, async (req, res) => {
try { try {
logger.info('Standard Gemini API models request (v1)') logger.info('Standard Gemini API models request (v1)')
// 直接调用 geminiRoutes 中的模型处理逻辑 // 直接调用 geminiRoutes 中的模型处理逻辑
@@ -663,7 +734,7 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
}) })
// 添加模型详情端点 // 添加模型详情端点
router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => { router.get('/v1beta/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => {
const { modelName } = req.params const { modelName } = req.params
logger.info(`Standard Gemini API model details request: ${modelName}`) logger.info(`Standard Gemini API model details request: ${modelName}`)
@@ -681,7 +752,7 @@ router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => {
}) })
}) })
router.get('/v1/models/:modelName', authenticateApiKey, (req, res) => { router.get('/v1/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => {
const { modelName } = req.params const { modelName } = req.params
logger.info(`Standard Gemini API model details request (v1): ${modelName}`) logger.info(`Standard Gemini API model details request (v1): ${modelName}`)