mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-23 00:53:33 +00:00
fix: apikey的服务权限问题修复
This commit is contained in:
@@ -29,6 +29,26 @@ function checkPermissions(apiKeyData, requiredPermission = 'gemini') {
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// 确保请求具有 Gemini 访问权限
|
||||
function ensureGeminiPermission(req, res) {
|
||||
const apiKeyData = req.apiKey || {}
|
||||
if (checkPermissions(apiKeyData, 'gemini')) {
|
||||
return true
|
||||
}
|
||||
|
||||
logger.security(
|
||||
`🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 Gemini 权限,拒绝访问 ${req.originalUrl}`
|
||||
)
|
||||
|
||||
res.status(403).json({
|
||||
error: {
|
||||
message: 'This API key does not have permission to access Gemini',
|
||||
type: 'permission_denied'
|
||||
}
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Gemini 消息处理端点
|
||||
router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
const startTime = Date.now()
|
||||
@@ -309,6 +329,10 @@ router.get('/key-info', authenticateApiKey, async (req, res) => {
|
||||
// 共用的 loadCodeAssist 处理函数
|
||||
async function handleLoadCodeAssist(req, res) {
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 从路径参数或请求体中获取模型名
|
||||
@@ -388,6 +412,10 @@ async function handleLoadCodeAssist(req, res) {
|
||||
// 共用的 onboardUser 处理函数
|
||||
async function handleOnboardUser(req, res) {
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 提取请求参数
|
||||
const { tierId, cloudaicompanionProject, metadata } = req.body
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
@@ -475,6 +503,10 @@ async function handleOnboardUser(req, res) {
|
||||
// 共用的 countTokens 处理函数
|
||||
async function handleCountTokens(req, res) {
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 处理请求体结构,支持直接 contents 或 request.contents
|
||||
const requestData = req.body.request || req.body
|
||||
const { contents } = requestData
|
||||
@@ -538,6 +570,10 @@ async function handleCountTokens(req, res) {
|
||||
// 共用的 generateContent 处理函数
|
||||
async function handleGenerateContent(req, res) {
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const { project, user_prompt_id, request: requestData } = req.body
|
||||
// 从路径参数或请求体中获取模型名
|
||||
const model = req.body.model || req.params.modelName || 'gemini-2.5-flash'
|
||||
@@ -676,6 +712,10 @@ async function handleStreamGenerateContent(req, res) {
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const { project, user_prompt_id, request: requestData } = req.body
|
||||
// 从路径参数或请求体中获取模型名
|
||||
const model = req.body.model || req.params.modelName || 'gemini-2.5-flash'
|
||||
|
||||
@@ -17,6 +17,12 @@ function createProxyAgent(proxy) {
|
||||
return ProxyHelper.createProxyAgent(proxy)
|
||||
}
|
||||
|
||||
// 检查 API Key 是否具备 OpenAI 权限
|
||||
function checkOpenAIPermissions(apiKeyData) {
|
||||
const permissions = apiKeyData?.permissions || 'all'
|
||||
return permissions === 'all' || permissions === 'openai'
|
||||
}
|
||||
|
||||
function normalizeHeaders(headers = {}) {
|
||||
if (!headers || typeof headers !== 'object') {
|
||||
return {}
|
||||
@@ -190,6 +196,19 @@ const handleResponses = async (req, res) => {
|
||||
// 从中间件获取 API Key 数据
|
||||
const apiKeyData = req.apiKey || {}
|
||||
|
||||
if (!checkOpenAIPermissions(apiKeyData)) {
|
||||
logger.security(
|
||||
`🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 OpenAI 权限,拒绝访问 ${req.originalUrl}`
|
||||
)
|
||||
return res.status(403).json({
|
||||
error: {
|
||||
message: 'This API key does not have permission to access OpenAI',
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 从请求头或请求体中提取会话 ID
|
||||
const sessionId =
|
||||
req.headers['session_id'] ||
|
||||
|
||||
@@ -10,6 +10,40 @@ const sessionHelper = require('../utils/sessionHelper')
|
||||
// 导入 geminiRoutes 中导出的处理函数
|
||||
const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('./geminiRoutes')
|
||||
|
||||
// 检查 API Key 是否具备 Gemini 权限
|
||||
function hasGeminiPermission(apiKeyData, requiredPermission = 'gemini') {
|
||||
const permissions = apiKeyData?.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// 确保请求拥有 Gemini 权限
|
||||
function ensureGeminiPermission(req, res) {
|
||||
const apiKeyData = req.apiKey || {}
|
||||
if (hasGeminiPermission(apiKeyData, 'gemini')) {
|
||||
return true
|
||||
}
|
||||
|
||||
logger.security(
|
||||
`🚫 API Key ${apiKeyData.id || 'unknown'} 缺少 Gemini 权限,拒绝访问 ${req.originalUrl}`
|
||||
)
|
||||
|
||||
res.status(403).json({
|
||||
error: {
|
||||
message: 'This API key does not have permission to access Gemini',
|
||||
type: 'permission_denied'
|
||||
}
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// 供路由中间件复用的权限检查
|
||||
function ensureGeminiPermissionMiddleware(req, res, next) {
|
||||
if (ensureGeminiPermission(req, res)) {
|
||||
return next()
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 标准 Gemini API 路由处理器
|
||||
// 这些路由将挂载在 /gemini 路径下,处理标准 Gemini API 格式的请求
|
||||
// 标准格式: /gemini/v1beta/models/{model}:generateContent
|
||||
@@ -17,6 +51,10 @@ const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('
|
||||
// 专门处理标准 Gemini API 格式的 generateContent
|
||||
async function handleStandardGenerateContent(req, res) {
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 从路径参数中获取模型名
|
||||
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
@@ -225,6 +263,10 @@ async function handleStandardStreamGenerateContent(req, res) {
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
if (!ensureGeminiPermission(req, res)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 从路径参数中获取模型名
|
||||
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
@@ -535,31 +577,48 @@ async function handleStandardStreamGenerateContent(req, res) {
|
||||
}
|
||||
|
||||
// v1beta 版本的标准路由 - 支持动态模型名称
|
||||
router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:loadCodeAssist',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
|
||||
handleLoadCodeAssist(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:onboardUser',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
|
||||
handleOnboardUser(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:countTokens',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request: ${req.method} ${req.originalUrl}`)
|
||||
handleCountTokens(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
// 使用专门的处理函数处理标准 Gemini API 格式
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:generateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardGenerateContent
|
||||
)
|
||||
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:streamGenerateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardStreamGenerateContent
|
||||
)
|
||||
|
||||
@@ -567,45 +626,52 @@ router.post(
|
||||
router.post(
|
||||
'/v1/models/:modelName\\:generateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardGenerateContent
|
||||
)
|
||||
|
||||
router.post(
|
||||
'/v1/models/:modelName\\:streamGenerateContent',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
handleStandardStreamGenerateContent
|
||||
)
|
||||
|
||||
router.post('/v1/models/:modelName\\:countTokens', authenticateApiKey, (req, res, next) => {
|
||||
router.post(
|
||||
'/v1/models/:modelName\\:countTokens',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
(req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1): ${req.method} ${req.originalUrl}`)
|
||||
handleCountTokens(req, res, next)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
// v1internal 版本的标准路由(这些使用原有的处理函数,因为格式不同)
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
handleLoadCodeAssist(req, res, next)
|
||||
})
|
||||
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
handleOnboardUser(req, res, next)
|
||||
})
|
||||
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
handleCountTokens(req, res, next)
|
||||
})
|
||||
|
||||
// v1internal 使用不同的处理逻辑,因为它们不包含模型在 URL 中
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
// v1internal 格式不同,使用原有的处理函数
|
||||
const { handleGenerateContent } = require('./geminiRoutes')
|
||||
handleGenerateContent(req, res, next)
|
||||
})
|
||||
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res, next) => {
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res, next) => {
|
||||
logger.info(`Standard Gemini API request (v1internal): ${req.method} ${req.originalUrl}`)
|
||||
// v1internal 格式不同,使用原有的处理函数
|
||||
const { handleStreamGenerateContent } = require('./geminiRoutes')
|
||||
@@ -613,32 +679,37 @@ router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, (req, res
|
||||
})
|
||||
|
||||
// 添加标准 Gemini API 的模型列表端点
|
||||
router.get('/v1beta/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
logger.info('Standard Gemini API models request')
|
||||
// 直接调用 geminiRoutes 中的模型处理逻辑
|
||||
const geminiRoutes = require('./geminiRoutes')
|
||||
const modelHandler = geminiRoutes.stack.find(
|
||||
(layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get
|
||||
)
|
||||
if (modelHandler && modelHandler.route.stack[1]) {
|
||||
// 调用处理函数(跳过第一个 authenticateApiKey 中间件)
|
||||
modelHandler.route.stack[1].handle(req, res)
|
||||
} else {
|
||||
res.status(500).json({ error: 'Models handler not found' })
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in standard models endpoint:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'api_error'
|
||||
router.get(
|
||||
'/v1beta/models',
|
||||
authenticateApiKey,
|
||||
ensureGeminiPermissionMiddleware,
|
||||
async (req, res) => {
|
||||
try {
|
||||
logger.info('Standard Gemini API models request')
|
||||
// 直接调用 geminiRoutes 中的模型处理逻辑
|
||||
const geminiRoutes = require('./geminiRoutes')
|
||||
const modelHandler = geminiRoutes.stack.find(
|
||||
(layer) => layer.route && layer.route.path === '/models' && layer.route.methods.get
|
||||
)
|
||||
if (modelHandler && modelHandler.route.stack[1]) {
|
||||
// 调用处理函数(跳过第一个 authenticateApiKey 中间件)
|
||||
modelHandler.route.stack[1].handle(req, res)
|
||||
} else {
|
||||
res.status(500).json({ error: 'Models handler not found' })
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error in standard models endpoint:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'api_error'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
router.get('/v1/models', authenticateApiKey, ensureGeminiPermissionMiddleware, async (req, res) => {
|
||||
try {
|
||||
logger.info('Standard Gemini API models request (v1)')
|
||||
// 直接调用 geminiRoutes 中的模型处理逻辑
|
||||
@@ -663,7 +734,7 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
})
|
||||
|
||||
// 添加模型详情端点
|
||||
router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => {
|
||||
router.get('/v1beta/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => {
|
||||
const { modelName } = req.params
|
||||
logger.info(`Standard Gemini API model details request: ${modelName}`)
|
||||
|
||||
@@ -681,7 +752,7 @@ router.get('/v1beta/models/:modelName', authenticateApiKey, (req, res) => {
|
||||
})
|
||||
})
|
||||
|
||||
router.get('/v1/models/:modelName', authenticateApiKey, (req, res) => {
|
||||
router.get('/v1/models/:modelName', authenticateApiKey, ensureGeminiPermissionMiddleware, (req, res) => {
|
||||
const { modelName } = req.params
|
||||
logger.info(`Standard Gemini API model details request (v1): ${modelName}`)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user