Files
claude-relay-service/src/routes/geminiRoutes.js
2025-08-04 22:45:11 +08:00

578 lines
18 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

const express = require('express');
const router = express.Router();
const logger = require('../utils/logger');
const { authenticateApiKey } = require('../middleware/auth');
const geminiAccountService = require('../services/geminiAccountService');
const { sendGeminiRequest, getAvailableModels } = require('../services/geminiRelayService');
const crypto = require('crypto');
const sessionHelper = require('../utils/sessionHelper');
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler');
// const { OAuth2Client } = require('google-auth-library'); // OAuth2Client is not used in this file
/**
 * Build the sticky-session hash for a request.
 *
 * The User-Agent header, the client IP and the first 10 characters of the
 * `x-api-key` header are joined with ':' (missing or empty parts are skipped)
 * and the SHA-256 digest of that string is returned.
 *
 * @param {object} req - Request object exposing `headers` and `ip`.
 * @returns {string} Hex-encoded SHA-256 digest.
 */
function generateSessionHash(req) {
  const { headers } = req;
  const candidates = [
    headers['user-agent'],
    req.ip,
    headers['x-api-key']?.substring(0, 10)
  ];

  // Keep only the parts that are actually present.
  const fingerprintParts = [];
  for (const candidate of candidates) {
    if (candidate) {
      fingerprintParts.push(candidate);
    }
  }

  const hasher = crypto.createHash('sha256');
  hasher.update(fingerprintParts.join(':'));
  return hasher.digest('hex');
}
/**
 * Check whether an API key may use a given service.
 *
 * A key with no `permissions` value is treated as `'all'`.
 *
 * @param {object} apiKeyData - API key record; `permissions` is read from it.
 * @param {string} [requiredPermission='gemini'] - Permission required by the caller.
 * @returns {boolean} True when the key's permission is `'all'` or equals `requiredPermission`.
 */
function checkPermissions(apiKeyData, requiredPermission = 'gemini') {
  const granted = apiKeyData.permissions || 'all';
  if (granted === 'all') {
    return true;
  }
  return granted === requiredPermission;
}
// Gemini message relay endpoint.
//
// Flow: permission check -> request validation -> account selection through
// the unified scheduler (sticky by session hash, filtered by model) -> relay
// via sendGeminiRequest -> stream the result as SSE or return it as JSON.
//
// Request body fields: `messages` (required, non-empty array), `model`
// (default 'gemini-2.0-flash-exp'), `temperature` (default 0.7),
// `max_tokens` (default 4096), `stream` (default false).
router.post('/messages', authenticateApiKey, async (req, res) => {
  const startTime = Date.now();
  let abortController = null;
  // Declared outside the try block so the catch block knows which account
  // served the request (needed to flag it as rate limited on a 429).
  let account = null;
  try {
    const apiKeyData = req.apiKey;
    // Check permissions
    if (!checkPermissions(apiKeyData, 'gemini')) {
      return res.status(403).json({
        error: {
          message: 'This API key does not have permission to access Gemini',
          type: 'permission_denied'
        }
      });
    }
    // Extract request parameters
    const {
      messages,
      model = 'gemini-2.0-flash-exp',
      temperature = 0.7,
      max_tokens = 4096,
      stream = false
    } = req.body;
    // Validate required parameters
    if (!messages || !Array.isArray(messages) || messages.length === 0) {
      return res.status(400).json({
        error: {
          message: 'Messages array is required',
          type: 'invalid_request_error'
        }
      });
    }
    // Session hash used for sticky sessions
    const sessionHash = generateSessionHash(req);
    // Pick an available Gemini account through the unified scheduler,
    // passing the requested model so it can filter accounts.
    let accountId;
    try {
      const schedulerResult = await unifiedGeminiScheduler.selectAccountForApiKey(
        apiKeyData,
        sessionHash,
        model
      );
      accountId = schedulerResult.accountId;
    } catch (error) {
      logger.error('Failed to select Gemini account:', error);
      return res.status(503).json({
        error: {
          message: error.message || 'No available Gemini accounts',
          type: 'service_unavailable'
        }
      });
    }
    // Load the account details
    account = await geminiAccountService.getAccount(accountId);
    if (!account) {
      return res.status(503).json({
        error: {
          message: 'Selected account not found',
          type: 'service_unavailable'
        }
      });
    }
    logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`);
    // Mark the account as used
    await geminiAccountService.markAccountUsed(account.id);
    // Abort controller used to cancel the upstream request
    abortController = new AbortController();
    // Abort the upstream request when the client disconnects. The null check
    // matters: the finally block clears the controller once the request is
    // done, so a later 'close' event does not abort a finished request.
    req.on('close', () => {
      if (abortController && !abortController.signal.aborted) {
        logger.info('Client disconnected, aborting Gemini request');
        abortController.abort();
      }
    });
    // Send the request to Gemini
    const geminiResponse = await sendGeminiRequest({
      messages,
      model,
      temperature,
      maxTokens: max_tokens,
      stream,
      accessToken: account.accessToken,
      proxy: account.proxy,
      apiKeyId: apiKeyData.id,
      signal: abortController.signal,
      projectId: account.projectId,
      accountId: account.id
    });
    if (stream) {
      // SSE response headers
      res.setHeader('Content-Type', 'text/event-stream');
      res.setHeader('Cache-Control', 'no-cache');
      res.setHeader('Connection', 'keep-alive');
      res.setHeader('X-Accel-Buffering', 'no');
      // Forward chunks until the upstream ends or the client goes away
      for await (const chunk of geminiResponse) {
        if (abortController.signal.aborted) {
          break;
        }
        res.write(chunk);
      }
      res.end();
    } else {
      // Non-streaming response
      res.json(geminiResponse);
    }
    const duration = Date.now() - startTime;
    logger.info(`Gemini request completed in ${duration}ms`);
  } catch (error) {
    logger.error('Gemini request error:', error);
    // Rate limiting: flag the account that actually served this request.
    // `req.account` is kept as a fallback in case a middleware populates it.
    const failedAccount = account || req.account;
    if (error.status === 429 && failedAccount) {
      try {
        await geminiAccountService.setAccountRateLimited(failedAccount.id, true);
      } catch (markError) {
        // A bookkeeping failure must not prevent the error response below.
        logger.error('Failed to mark Gemini account as rate limited:', markError);
      }
    }
    if (res.headersSent) {
      // The stream has already started, so a JSON error can no longer be
      // sent; calling status()/json() here would throw. Just close it.
      if (!res.writableEnded) {
        res.end();
      }
    } else {
      // Return the error response
      const status = error.status || 500;
      const errorResponse = {
        error: error.error || {
          message: error.message || 'Internal server error',
          type: 'api_error'
        }
      };
      res.status(status).json(errorResponse);
    }
  } finally {
    // Drop the controller so the 'close' listener becomes a no-op
    if (abortController) {
      abortController = null;
    }
  }
});
// List available models.
//
// Responds with `{ object: 'list', data: [...] }`. When no account can be
// selected for the API key, a single default model entry is returned instead
// of querying upstream.
router.get('/models', authenticateApiKey, async (req, res) => {
  try {
    const apiKeyData = req.apiKey;
    // Check permissions
    if (!checkPermissions(apiKeyData, 'gemini')) {
      return res.status(403).json({
        error: {
          message: 'This API key does not have permission to access Gemini',
          type: 'permission_denied'
        }
      });
    }
    // Select an account to fetch the model list with
    const account = await geminiAccountService.selectAvailableAccount(apiKeyData.id);
    if (!account) {
      // Fall back to the default model list
      return res.json({
        object: 'list',
        data: [
          {
            id: 'gemini-2.0-flash-exp',
            object: 'model',
            // Whole seconds: Date.now() is in milliseconds, so the plain
            // division would produce a fractional timestamp.
            created: Math.floor(Date.now() / 1000),
            owned_by: 'google'
          }
        ]
      });
    }
    // Fetch the model list
    const models = await getAvailableModels(account.accessToken, account.proxy);
    res.json({
      object: 'list',
      data: models
    });
  } catch (error) {
    logger.error('Failed to get Gemini models:', error);
    res.status(500).json({
      error: {
        message: 'Failed to retrieve models',
        type: 'api_error'
      }
    });
  }
});
// Usage statistics (shared with Claude).
//
// Reports the total, daily and monthly token/request counters found on
// `req.apiKey.usage`. Any failure while reading them yields a 500 response.
router.get('/usage', authenticateApiKey, async (req, res) => {
  try {
    const { total, daily, monthly } = req.apiKey.usage;
    const payload = {
      object: 'usage',
      total_tokens: total.tokens,
      total_requests: total.requests,
      daily_tokens: daily.tokens,
      daily_requests: daily.requests,
      monthly_tokens: monthly.tokens,
      monthly_requests: monthly.requests
    };
    res.json(payload);
  } catch (err) {
    logger.error('Failed to get usage stats:', err);
    const failure = {
      error: {
        message: 'Failed to retrieve usage statistics',
        type: 'api_error'
      }
    };
    res.status(500).json(failure);
  }
});
// API key information (shared with Claude).
//
// Returns the key's identity, permissions, token limit and usage, rate-limit
// settings, concurrency limit and model restrictions, all read from
// `req.apiKey`. `tokens_remaining` is null when `tokenLimit` is not positive.
router.get('/key-info', authenticateApiKey, async (req, res) => {
  try {
    const key = req.apiKey;
    const usedTokens = key.usage.total.tokens;

    let remainingTokens = null;
    if (key.tokenLimit > 0) {
      remainingTokens = Math.max(0, key.tokenLimit - usedTokens);
    }

    const rateLimit = {
      window: key.rateLimitWindow,
      requests: key.rateLimitRequests
    };
    const modelRestrictions = {
      enabled: key.enableModelRestriction,
      models: key.restrictedModels
    };

    res.json({
      id: key.id,
      name: key.name,
      permissions: key.permissions || 'all',
      token_limit: key.tokenLimit,
      tokens_used: usedTokens,
      tokens_remaining: remainingTokens,
      rate_limit: rateLimit,
      concurrency_limit: key.concurrencyLimit,
      model_restrictions: modelRestrictions
    });
  } catch (err) {
    logger.error('Failed to get key info:', err);
    const failure = {
      error: {
        message: 'Failed to retrieve API key information',
        type: 'api_error'
      }
    };
    res.status(500).json(failure);
  }
});
/**
 * Shared loadCodeAssist handler (used by the v1internal and v1beta routes).
 *
 * Selects an account through the unified scheduler, builds an OAuth client
 * from that account's tokens and responds with the result of
 * `geminiAccountService.loadCodeAssist`. Any failure yields a 500 response
 * carrying the error message.
 *
 * @param {object} req - Express request; reads `body`, `path` and `apiKey`.
 * @param {object} res - Express response.
 * @returns {Promise<void>}
 */
async function handleLoadCodeAssist(req, res) {
  try {
    const sessionHash = sessionHelper.generateSessionHash(req.body);
    // Select an account through the unified scheduler (passing the requested model)
    const requestedModel = req.body.model;
    const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
    const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);
    // The access token is a credential and must never be written to the logs.
    const { metadata, cloudaicompanionProject } = req.body;
    const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
    logger.info(`LoadCodeAssist request (${version})`, {
      metadata: metadata || {},
      cloudaicompanionProject: cloudaicompanionProject || null,
      apiKeyId: req.apiKey?.id || 'unknown'
    });
    const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
    const response = await geminiAccountService.loadCodeAssist(client, cloudaicompanionProject);
    res.json(response);
  } catch (error) {
    const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
    logger.error(`Error in loadCodeAssist endpoint (${version})`, { error: error.message });
    res.status(500).json({
      error: 'Internal server error',
      message: error.message
    });
  }
}
/**
 * Shared onboardUser handler (used by the v1internal and v1beta routes).
 *
 * Selects an account through the unified scheduler and builds an OAuth client
 * from its tokens. When both `tierId` and `metadata` are present in the body,
 * `geminiAccountService.onboardUser` is called directly; otherwise the
 * `geminiAccountService.setupUser` flow is run. Any failure yields a 500
 * response carrying the error message.
 *
 * @param {object} req - Express request; reads `body`, `path` and `apiKey`.
 * @param {object} res - Express response.
 * @returns {Promise<void>}
 */
async function handleOnboardUser(req, res) {
  // API version label used in log messages, derived from the request path.
  const versionLabel = () => (req.path.includes('v1beta') ? 'v1beta' : 'v1internal');

  try {
    const { tierId, cloudaicompanionProject, metadata } = req.body;
    const stickyHash = sessionHelper.generateSessionHash(req.body);

    // Select an account through the unified scheduler (passing the requested model)
    const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
      req.apiKey,
      stickyHash,
      req.body.model
    );
    const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);

    logger.info(`OnboardUser request (${versionLabel()})`, {
      tierId: tierId || 'not provided',
      cloudaicompanionProject: cloudaicompanionProject || null,
      metadata: metadata || {},
      apiKeyId: req.apiKey?.id || 'unknown'
    });

    const oauthClient = await geminiAccountService.getOauthClient(accessToken, refreshToken);

    // With complete parameters call onboardUser directly; otherwise run the
    // full setupUser flow.
    const result = tierId && metadata
      ? await geminiAccountService.onboardUser(oauthClient, tierId, cloudaicompanionProject, metadata)
      : await geminiAccountService.setupUser(oauthClient, cloudaicompanionProject, metadata);
    res.json(result);
  } catch (err) {
    logger.error(`Error in onboardUser endpoint (${versionLabel()})`, { error: err.message });
    res.status(500).json({
      error: 'Internal server error',
      message: err.message
    });
  }
}
/**
 * Shared countTokens handler (used by the v1internal and v1beta routes).
 *
 * Accepts `contents` either at the top level of the body or nested under
 * `request`. Responds 400 when `contents` is not an array; otherwise selects
 * an account through the unified scheduler and responds with the result of
 * `geminiAccountService.countTokens`. Any failure yields a 500 response.
 *
 * @param {object} req - Express request; reads `body`, `path` and `apiKey`.
 * @param {object} res - Express response.
 * @returns {Promise<void>}
 */
async function handleCountTokens(req, res) {
  // API version label used in log messages, derived from the request path.
  const versionLabel = () => (req.path.includes('v1beta') ? 'v1beta' : 'v1internal');

  try {
    // Support both a direct `contents` and a nested `request.contents`.
    const payload = req.body.request || req.body;
    const { contents, model = 'gemini-2.0-flash-exp' } = payload;
    const stickyHash = sessionHelper.generateSessionHash(req.body);

    // Validate required parameters (a missing value is not an array either).
    if (!Array.isArray(contents)) {
      const badRequest = {
        error: {
          message: 'Contents array is required',
          type: 'invalid_request_error'
        }
      };
      return res.status(400).json(badRequest);
    }

    // Select an account through the unified scheduler
    const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, stickyHash, model);
    const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);

    logger.info(`CountTokens request (${versionLabel()})`, {
      model: model,
      contentsLength: contents.length,
      apiKeyId: req.apiKey?.id || 'unknown'
    });

    const oauthClient = await geminiAccountService.getOauthClient(accessToken, refreshToken);
    const tokenCount = await geminiAccountService.countTokens(oauthClient, contents, model);
    res.json(tokenCount);
  } catch (err) {
    logger.error(`Error in countTokens endpoint (${versionLabel()})`, { error: err.message });
    res.status(500).json({
      error: {
        message: err.message || 'Internal server error',
        type: 'api_error'
      }
    });
  }
}
/**
 * Shared generateContent handler (used by the v1internal and v1beta routes).
 *
 * Responds 400 when `request.contents` is missing from the body; otherwise
 * selects an account through the unified scheduler and responds with the
 * result of `geminiAccountService.generateContent`. The project comes from
 * the body's `project`, falling back to the account's `projectId`. Any
 * failure yields a 500 response.
 *
 * @param {object} req - Express request; reads `body`, `path` and `apiKey`.
 * @param {object} res - Express response.
 * @returns {Promise<void>}
 */
async function handleGenerateContent(req, res) {
  // API version label used in log messages, derived from the request path.
  const versionLabel = () => (req.path.includes('v1beta') ? 'v1beta' : 'v1internal');

  try {
    const { model, project, user_prompt_id, request: requestData } = req.body;
    const stickyHash = sessionHelper.generateSessionHash(req.body);

    // Validate required parameters
    if (!requestData?.contents) {
      const badRequest = {
        error: {
          message: 'Request contents are required',
          type: 'invalid_request_error'
        }
      };
      return res.status(400).json(badRequest);
    }

    // Select an account through the unified scheduler
    const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, stickyHash, model);
    const account = await geminiAccountService.getAccount(accountId);
    const { accessToken, refreshToken } = account;
    const projectId = project || account.projectId;

    logger.info(`GenerateContent request (${versionLabel()})`, {
      model: model,
      userPromptId: user_prompt_id,
      projectId: projectId,
      apiKeyId: req.apiKey?.id || 'unknown'
    });

    const oauthClient = await geminiAccountService.getOauthClient(accessToken, refreshToken);
    // The API key ID is passed as the session ID.
    const generated = await geminiAccountService.generateContent(
      oauthClient,
      { model, request: requestData },
      user_prompt_id,
      projectId,
      req.apiKey?.id
    );
    res.json(generated);
  } catch (err) {
    logger.error(`Error in generateContent endpoint (${versionLabel()})`, { error: err.message });
    res.status(500).json({
      error: {
        message: err.message || 'Internal server error',
        type: 'api_error'
      }
    });
  }
}
/**
 * Shared streamGenerateContent handler (used by the v1internal and v1beta
 * routes).
 *
 * Responds 400 when `request.contents` is missing from the body; otherwise
 * selects an account through the unified scheduler, opens the upstream stream
 * with `geminiAccountService.generateContentStream` and pipes it to the
 * client as an SSE response without further processing.
 *
 * This function returns as soon as the pipe is set up, while the stream keeps
 * running. The abort controller therefore has to stay usable after the
 * function returns; a completion flag (rather than clearing the controller)
 * decides whether a client disconnect should still abort the upstream call.
 *
 * @param {object} req - Express request; reads `body`, `path` and `apiKey`.
 * @param {object} res - Express response.
 * @returns {Promise<void>}
 */
async function handleStreamGenerateContent(req, res) {
  // Set once the upstream stream has ended or failed, or when setup failed
  // before a stream existed: from then on there is nothing left to abort.
  let upstreamSettled = false;
  try {
    const { model, project, user_prompt_id, request: requestData } = req.body;
    const sessionHash = sessionHelper.generateSessionHash(req.body);
    // Validate required parameters
    if (!requestData || !requestData.contents) {
      return res.status(400).json({
        error: {
          message: 'Request contents are required',
          type: 'invalid_request_error'
        }
      });
    }
    // Select an account through the unified scheduler
    const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model);
    const account = await geminiAccountService.getAccount(accountId);
    const { accessToken, refreshToken } = account;
    const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
    logger.info(`StreamGenerateContent request (${version})`, {
      model: model,
      userPromptId: user_prompt_id,
      projectId: project || account.projectId,
      apiKeyId: req.apiKey?.id || 'unknown'
    });
    // Abort controller used to cancel the upstream request
    const abortController = new AbortController();
    // Abort the upstream request when the client disconnects while the
    // stream is still in flight.
    req.on('close', () => {
      if (!upstreamSettled && !abortController.signal.aborted) {
        logger.info('Client disconnected, aborting stream request');
        abortController.abort();
      }
    });
    const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
    const streamResponse = await geminiAccountService.generateContentStream(
      client,
      { model, request: requestData },
      user_prompt_id,
      project || account.projectId,
      req.apiKey?.id, // the API key ID is used as the session ID
      abortController.signal // lets the service cancel the upstream call
    );
    // SSE response headers
    res.setHeader('Content-Type', 'text/event-stream');
    res.setHeader('Cache-Control', 'no-cache');
    res.setHeader('Connection', 'keep-alive');
    res.setHeader('X-Accel-Buffering', 'no');
    // Pipe the upstream stream straight through; the response is ended
    // explicitly in the handlers below (hence `end: false`).
    streamResponse.pipe(res, { end: false });
    streamResponse.on('end', () => {
      upstreamSettled = true;
      logger.info('Stream completed successfully');
      res.end();
    });
    streamResponse.on('error', (error) => {
      upstreamSettled = true;
      logger.error('Stream error:', error);
      if (!res.headersSent) {
        res.status(500).json({
          error: {
            message: error.message || 'Stream error',
            type: 'api_error'
          }
        });
      } else {
        res.end();
      }
    });
  } catch (error) {
    // Setup failed, so no upstream stream is running.
    upstreamSettled = true;
    const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
    logger.error(`Error in streamGenerateContent endpoint (${version})`, { error: error.message });
    if (!res.headersSent) {
      res.status(500).json({
        error: {
          message: error.message || 'Internal server error',
          type: 'api_error'
        }
      });
    }
  }
}
// Register the Code Assist style endpoints.
//
// Every action is exposed under two prefixes, in this order:
//   1. v1internal:              /v1internal:<action>
//   2. v1beta (dynamic model):  /v1beta/models/:modelName:<action>
// The colon before the action name is escaped (`\\:`) so it is matched
// literally instead of being read as a route parameter marker.
{
  const actionHandlers = [
    ['loadCodeAssist', handleLoadCodeAssist],
    ['onboardUser', handleOnboardUser],
    ['countTokens', handleCountTokens],
    ['generateContent', handleGenerateContent],
    ['streamGenerateContent', handleStreamGenerateContent]
  ];
  const routePrefixes = ['/v1internal', '/v1beta/models/:modelName'];

  for (const prefix of routePrefixes) {
    for (const [action, handler] of actionHandlers) {
      router.post(`${prefix}\\:${action}`, authenticateApiKey, handler);
    }
  }
}
module.exports = router;