mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-23 09:38:02 +00:00
fix: apikey的服务权限问题修复
This commit is contained in:
@@ -10,6 +10,40 @@ const sessionHelper = require('../utils/sessionHelper')
|
||||
// 导入 geminiRoutes 中导出的处理函数
|
||||
const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('./geminiRoutes')
|
||||
|
||||
// 检查 API Key 是否具备 Gemini 权限
|
||||
function hasGeminiPermission(apiKeyData, requiredPermission = 'gemini') {
|
||||
const permissions = apiKeyData?.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// 确保请求拥有 Gemini 权限
|
||||
function ensureGeminiPermission(req, res) {
|
||||
const apiKeyData = req.apiKey || {}
|
||||
if (hasGeminiPermission(apiKeyData, 'gemini')) {
|
||||
return true
|
||||
}
|
||||
|
||||
logger.security(
|
||||
`🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 Gemini 权限,拒绝访问 ${req.originalUrl}`
|
||||
)
|
||||
|
||||
res.status(403).json({
|
||||
error: {
|
||||
message: 'This API key does not have permission to access Gemini',
|
||||
type: 'permission_denied'
|
||||
}
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// 供路由中间件复用的权限检查
|
||||
function ensureGeminiPermissionMiddleware(req, res, next) {
|
||||
if (ensureGeminiPermission(req, res)) {
|
||||
return next()
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 标准 Gemini API 路由处理器
|
||||
// 这些路由将挂载在 /gemini 路径下,处理标准 Gemini API 格式的请求
|
||||
// 标准格式: /gemini/v1beta/models/{model}:generateContent
|
||||
@@ -17,6 +51,10 @@ const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('
|
||||
// 专门处理标准 Gemini API 格式的 generateContent
|
||||
async function handleStandardGenerateContent(req, res) {
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 从路径参数中获取模型名
|
||||
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
@@ -225,6 +263,10 @@ async function handleStandardStreamGenerateContent(req, res) {
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 从路径参数中获取模型名
|
||||
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
@@ -535,31 +577,48 @@ async function handleStandardStreamGenerateContent(req, res) {
|
||||
}
|
||||
|
||||
// v1beta 版本的标准路由 - 支持动态模型名称
|
||||
router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:loadCodeAssist',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
|
||||
handleLoadCodeAssist(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:onboardUser',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
|
||||
handleOnboardUser(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:countTokens',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
|
||||
handleCountTokens(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
// 使用专门的处理函数处理标准 Gemini API 格式
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:generateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardGenerateContent
|
||||
)
|
||||
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:streamGenerateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardStreamGenerateContent
|
||||
)
|
||||
|
||||
@@ -567,45 +626,52 @@ router.post(
|
||||
router.post(
|
||||
'/v1/models/:modelName\\:generateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardGenerateContent
|
||||
)
|
||||
|
||||
router.post(
|
||||
'/v1/models/:modelName\\:streamGenerateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardStreamGenerateContent
|
||||
)
|
||||
|
||||
router.post('/v1/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1/models/:modelName\\:countTokens',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1): ${req.method} ${req.originalUrl}`)
|
||||
handleCountTokens(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
// v1internal 版本的标准路由(这些使用原有的处理函数,因为格式不同)
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
handleLoadCodeAssist(req, res, next)
|
||||
})
|
||||
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
handleOnboardUser(req, res, next)
|
||||
})
|
||||
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
handleCountTokens(req, res, next)
|
||||
})
|
||||
|
||||
// v1internal 使用不同的处理逻辑,因为它们不包含模型在 URL 中
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
// v1internal 格式不同,使用原有的处理函数
|
||||
const { handleGenerateContent } = require('./geminiRoutes')
|
||||
handleGenerateContent(req, res, next)
|
||||
})
|
||||
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
// v1internal 格式不同,使用原有的处理函数
|
||||
const { handleStreamGenerateContent } = require('./geminiRoutes')
|
||||
@@ -613,32 +679,37 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res
|
||||
})
|
||||
|
||||
// 添加标准 Gemini API 的模型列表端点
|
||||
router.get('/v1beta/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
logger.info('Standard Gemini API models request')
|
||||
// 直接调用 geminiRoutes 中的模型处理逻辑
|
||||
const geminiRoutes = require('./geminiRoutes')
|
||||
const modelHandler = geminiRoutes.stack.find(
|
||||
(layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get
|
||||
)
|
||||
if (modelHandler && modelHandler.route.stack[1]) {
|
||||
// 调用处理函数(跳过第一个 authenticateApiKey 中间件)
|
||||
modelHandler.route.stack[1].handle(req, res)
|
||||
} else {
|
||||
res.status(500).json({ error: 'Models handler not found' })
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in standard models endpoint:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'api_error'
|
||||
router.get(
|
||||
'/v1beta/models',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
async (req, res) => {
|
||||
try {
|
||||
logger.info('Standard Gemini API models request')
|
||||
// 直接调用 geminiRoutes 中的模型处理逻辑
|
||||
const geminiRoutes = require('./geminiRoutes')
|
||||
const modelHandler = geminiRoutes.stack.find(
|
||||
(layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get
|
||||
)
|
||||
if (modelHandler && modelHandler.route.stack[1]) {
|
||||
// 调用处理函数(跳过第一个 authenticateApiKey 中间件)
|
||||
modelHandler.route.stack[1].handle(req, res)
|
||||
} else {
|
||||
res.status(500).json({ error: 'Models handler not found' })
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error in standard models endpoint:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'api_error'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
router.get('/v1/models', authenticateApiKey, ensureGeminiPermissionMiddleware, async (req, res) => {
|
||||
try {
|
||||
logger.info('Standard Gemini API models request (v1)')
|
||||
// 直接调用 geminiRoutes 中的模型处理逻辑
|
||||
@@ -663,7 +734,7 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
})
|
||||
|
||||
// 添加模型详情端点
|
||||
router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => {
|
||||
router.get('/v1beta/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => {
|
||||
const { modelName } = req.params
|
||||
logger.info(`Standard Gemini API model details request: ${modelName}`)
|
||||
|
||||
@@ -681,7 +752,7 @@ router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => {
|
||||
})
|
||||
})
|
||||
|
||||
router.get('/v1/models/:modelName', authenticateApiKey, (req, res) => {
|
||||
router.get('/v1/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => {
|
||||
const { modelName } = req.params
|
||||
logger.info(`Standard Gemini API model details request (v1): ${modelName}`)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user