diff --git a/src/routes/geminiRoutes.js b/src/routes/geminiRoutes.js index c5d706a3..2a1525f2 100644 --- a/src/routes/geminiRoutes.js +++ b/src/routes/geminiRoutes.js @@ -50,7 +50,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => { // 提取请求参数 const { messages, - model = 'gemini-2.0-flash-exp', + model = 'gemini-2.5-flash', temperature = 0.7, max_tokens = 4096, stream = false @@ -217,7 +217,7 @@ router.get('/models', authenticateApiKey, async (req, res) => { object: 'list', data: [ { - id: 'gemini-2.0-flash-exp', + id: 'gemini-2.5-flash', object: 'model', created: Date.now() / 1000, owned_by: 'google' @@ -311,8 +311,8 @@ async function handleLoadCodeAssist(req, res) { try { const sessionHash = sessionHelper.generateSessionHash(req.body) - // 使用统一调度选择账号(传递请求的模型) - const requestedModel = req.body.model + // 从路径参数或请求体中获取模型名 + const requestedModel = req.body.model || req.params.modelName || 'gemini-2.5-flash' const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey( req.apiKey, sessionHash, @@ -368,8 +368,8 @@ async function handleOnboardUser(req, res) { const { tierId, cloudaicompanionProject, metadata } = req.body const sessionHash = sessionHelper.generateSessionHash(req.body) - // 使用统一调度选择账号(传递请求的模型) - const requestedModel = req.body.model + // 从路径参数或请求体中获取模型名 + const requestedModel = req.body.model || req.params.modelName || 'gemini-2.5-flash' const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey( req.apiKey, sessionHash, @@ -439,7 +439,9 @@ async function handleCountTokens(req, res) { try { // 处理请求体结构,支持直接 contents 或 request.contents const requestData = req.body.request || req.body - const { contents, model = 'gemini-2.0-flash-exp' } = requestData + const { contents } = requestData + // 从路径参数或请求体中获取模型名 + const model = requestData.model || req.params.modelName || 'gemini-2.5-flash' const sessionHash = sessionHelper.generateSessionHash(req.body) // 验证必需参数 @@ -487,7 +489,9 @@ async function handleCountTokens(req, res) { // 共用的 generateContent 处理函数 async function handleGenerateContent(req, res) { try { - const { model, project, user_prompt_id, request: requestData } = req.body + const { project, user_prompt_id, request: requestData } = req.body + // 从路径参数或请求体中获取模型名 + const model = req.body.model || req.params.modelName || 'gemini-2.5-flash' const sessionHash = sessionHelper.generateSessionHash(req.body) // 处理不同格式的请求 @@ -582,7 +586,7 @@ async function handleGenerateContent(req, res) { } } - res.json(response) + res.json(version === 'v1beta' ? response.response : response) } catch (error) { const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal' // 打印详细的错误信息 @@ -610,7 +614,9 @@ async function handleStreamGenerateContent(req, res) { let abortController = null try { - const { model, project, user_prompt_id, request: requestData } = req.body + const { project, user_prompt_id, request: requestData } = req.body + // 从路径参数或请求体中获取模型名 + const model = req.body.model || req.params.modelName || 'gemini-2.5-flash' const sessionHash = sessionHelper.generateSessionHash(req.body) // 处理不同格式的请求 @@ -702,8 +708,28 @@ async function handleStreamGenerateContent(req, res) { res.setHeader('Connection', 'keep-alive') res.setHeader('X-Accel-Buffering', 'no') + // SSE 解析函数 + const parseSSELine = (line) => { + if (!line.startsWith('data: ')) { + return { type: 'other', line, data: null } + } + + const jsonStr = line.substring(6).trim() + + if (!jsonStr || jsonStr === '[DONE]') { + return { type: 'control', line, data: null, jsonStr } + } + + try { + const data = JSON.parse(jsonStr) + return { type: 'data', line, data, jsonStr } + } catch (e) { + return { type: 'invalid', line, data: null, jsonStr, error: e } + } + } + // 处理流式响应并捕获usage数据 - let buffer = '' + let streamBuffer = '' // 统一的流处理缓冲区 let totalUsage = { promptTokenCount: 0, candidatesTokenCount: 0, @@ -715,32 +741,60 @@ async function handleStreamGenerateContent(req, res) { try { const chunkStr = chunk.toString() - // 直接转发数据到客户端 - if (!res.destroyed) { - res.write(chunkStr) + if (!chunkStr.trim()) { + return } - // 同时解析数据以捕获usage信息 - buffer += chunkStr - const lines = buffer.split('\n') - buffer = lines.pop() || '' + // 使用统一缓冲区处理不完整的行 + streamBuffer += chunkStr + const lines = streamBuffer.split('\n') + streamBuffer = lines.pop() || '' // 保留最后一个不完整的行 + + const processedLines = [] for (const line of lines) { - if (line.startsWith('data: ') && line.length > 6) { - try { - const jsonStr = line.slice(6) - if (jsonStr && jsonStr !== '[DONE]') { - const data = JSON.parse(jsonStr) + if (!line.trim()) { + continue // 跳过空行,不添加到处理队列 + } - // 从响应中提取usage数据 - if (data.response?.usageMetadata) { - totalUsage = data.response.usageMetadata - logger.debug('📊 Captured Gemini usage data:', totalUsage) - } + // 解析 SSE 行 + const parsed = parseSSELine(line) + + // 提取 usage 数据(适用于所有版本) + if (parsed.type === 'data' && parsed.data.response?.usageMetadata) { + totalUsage = parsed.data.response.usageMetadata + logger.debug('📊 Captured Gemini usage data:', totalUsage) + } + + // 根据版本处理输出 + if (version === 'v1beta') { + if (parsed.type === 'data') { + if (parsed.data.response) { + // 有 response 字段,只返回 response 的内容 + processedLines.push(`data: ${JSON.stringify(parsed.data.response)}`) + } else { + // 没有 response 字段,返回整个数据对象 + processedLines.push(`data: ${JSON.stringify(parsed.data)}`) } - } catch (e) { - // 忽略解析错误 + } else if (parsed.type === 'control') { + // 控制消息(如 [DONE])保持原样 + processedLines.push(line) } + // 跳过其他类型的行('other', 'invalid') + } + } + + // 发送数据到客户端 + if (version === 'v1beta') { + for (const line of processedLines) { + if (!res.destroyed) { + res.write(`${line}\n\n`) + } + } + } else { + // v1internal 直接转发原始数据 + if (!res.destroyed) { + res.write(chunkStr) } } } catch (error) {