mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-22 16:43:35 +00:00
Merge branch 'pr/gemini-ratelimit' into dev
This commit is contained in:
@@ -6,7 +6,6 @@ const geminiAccountService = require('../services/geminiAccountService')
|
|||||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler')
|
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler')
|
||||||
const apiKeyService = require('../services/apiKeyService')
|
const apiKeyService = require('../services/apiKeyService')
|
||||||
const sessionHelper = require('../utils/sessionHelper')
|
const sessionHelper = require('../utils/sessionHelper')
|
||||||
const { parseSSELine } = require('../utils/sseParser')
|
|
||||||
|
|
||||||
// 导入 geminiRoutes 中导出的处理函数
|
// 导入 geminiRoutes 中导出的处理函数
|
||||||
const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('./geminiRoutes')
|
const { handleLoadCodeAssist, handleOnboardUser, handleCountTokens } = require('./geminiRoutes')
|
||||||
@@ -135,6 +134,9 @@ async function normalizeAxiosStreamError(error) {
|
|||||||
|
|
||||||
// 专门处理标准 Gemini API 格式的 generateContent
|
// 专门处理标准 Gemini API 格式的 generateContent
|
||||||
async function handleStandardGenerateContent(req, res) {
|
async function handleStandardGenerateContent(req, res) {
|
||||||
|
let account = null
|
||||||
|
let sessionHash = null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (!ensureGeminiPermission(req, res)) {
|
if (!ensureGeminiPermission(req, res)) {
|
||||||
return undefined
|
return undefined
|
||||||
@@ -142,7 +144,7 @@ async function handleStandardGenerateContent(req, res) {
|
|||||||
|
|
||||||
// 从路径参数中获取模型名
|
// 从路径参数中获取模型名
|
||||||
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
||||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||||
|
|
||||||
// 标准 Gemini API 请求体直接包含 contents 等字段
|
// 标准 Gemini API 请求体直接包含 contents 等字段
|
||||||
const { contents, generationConfig, safetySettings, systemInstruction, tools, toolConfig } =
|
const { contents, generationConfig, safetySettings, systemInstruction, tools, toolConfig } =
|
||||||
@@ -213,7 +215,7 @@ async function handleStandardGenerateContent(req, res) {
|
|||||||
sessionHash,
|
sessionHash,
|
||||||
model
|
model
|
||||||
)
|
)
|
||||||
const account = await geminiAccountService.getAccount(accountId)
|
account = await geminiAccountService.getAccount(accountId)
|
||||||
const { accessToken, refreshToken } = account
|
const { accessToken, refreshToken } = account
|
||||||
|
|
||||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1'
|
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1'
|
||||||
@@ -323,6 +325,17 @@ async function handleStandardGenerateContent(req, res) {
|
|||||||
responseData: error.response?.data,
|
responseData: error.response?.data,
|
||||||
stack: error.stack
|
stack: error.stack
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 处理速率限制
|
||||||
|
if (error.response?.status === 429) {
|
||||||
|
logger.warn(`⚠️ Gemini account ${account.id} rate limited (Standard API), marking as limited`)
|
||||||
|
try {
|
||||||
|
await unifiedGeminiScheduler.markAccountRateLimited(account.id, 'gemini', sessionHash)
|
||||||
|
} catch (limitError) {
|
||||||
|
logger.warn('Failed to mark account as rate limited in scheduler:', limitError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
res.status(500).json({
|
res.status(500).json({
|
||||||
error: {
|
error: {
|
||||||
message: error.message || 'Internal server error',
|
message: error.message || 'Internal server error',
|
||||||
@@ -335,6 +348,8 @@ async function handleStandardGenerateContent(req, res) {
|
|||||||
// 专门处理标准 Gemini API 格式的 streamGenerateContent
|
// 专门处理标准 Gemini API 格式的 streamGenerateContent
|
||||||
async function handleStandardStreamGenerateContent(req, res) {
|
async function handleStandardStreamGenerateContent(req, res) {
|
||||||
let abortController = null
|
let abortController = null
|
||||||
|
let account = null
|
||||||
|
let sessionHash = null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (!ensureGeminiPermission(req, res)) {
|
if (!ensureGeminiPermission(req, res)) {
|
||||||
@@ -343,7 +358,7 @@ async function handleStandardStreamGenerateContent(req, res) {
|
|||||||
|
|
||||||
// 从路径参数中获取模型名
|
// 从路径参数中获取模型名
|
||||||
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
const model = req.params.modelName || 'gemini-2.0-flash-exp'
|
||||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||||
|
|
||||||
// 标准 Gemini API 请求体直接包含 contents 等字段
|
// 标准 Gemini API 请求体直接包含 contents 等字段
|
||||||
const { contents, generationConfig, safetySettings, systemInstruction, tools, toolConfig } =
|
const { contents, generationConfig, safetySettings, systemInstruction, tools, toolConfig } =
|
||||||
@@ -414,7 +429,7 @@ async function handleStandardStreamGenerateContent(req, res) {
|
|||||||
sessionHash,
|
sessionHash,
|
||||||
model
|
model
|
||||||
)
|
)
|
||||||
const account = await geminiAccountService.getAccount(accountId)
|
account = await geminiAccountService.getAccount(accountId)
|
||||||
const { accessToken, refreshToken } = account
|
const { accessToken, refreshToken } = account
|
||||||
|
|
||||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1'
|
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1'
|
||||||
@@ -511,7 +526,6 @@ async function handleStandardStreamGenerateContent(req, res) {
|
|||||||
|
|
||||||
// 处理流式响应并捕获usage数据
|
// 处理流式响应并捕获usage数据
|
||||||
// 方案 A++:透明转发 + 异步 usage 提取 + SSE 心跳机制
|
// 方案 A++:透明转发 + 异步 usage 提取 + SSE 心跳机制
|
||||||
let streamBuffer = '' // 缓冲区用于处理不完整的行
|
|
||||||
let totalUsage = {
|
let totalUsage = {
|
||||||
promptTokenCount: 0,
|
promptTokenCount: 0,
|
||||||
candidatesTokenCount: 0,
|
candidatesTokenCount: 0,
|
||||||
@@ -538,55 +552,61 @@ async function handleStandardStreamGenerateContent(req, res) {
|
|||||||
// 更新最后数据时间
|
// 更新最后数据时间
|
||||||
lastDataTime = Date.now()
|
lastDataTime = Date.now()
|
||||||
|
|
||||||
// 1️⃣ 立即转发原始数据(零延迟,最高优先级)
|
const chunkStr = chunk.toString()
|
||||||
|
|
||||||
|
// 尝试解析 SSE 数据
|
||||||
|
// upstream 返回格式: data: {"response": {...}}
|
||||||
|
// standard API 期望格式: data: {...}
|
||||||
|
|
||||||
|
let processedChunk = chunk
|
||||||
|
|
||||||
|
if (chunkStr.startsWith('data: ')) {
|
||||||
|
try {
|
||||||
|
const jsonStr = chunkStr.substring(6).trim()
|
||||||
|
if (jsonStr !== '[DONE]') {
|
||||||
|
const data = JSON.parse(jsonStr)
|
||||||
|
if (data.response) {
|
||||||
|
// 提取内部的 response 对象并重新包装为 SSE
|
||||||
|
const newPayload = JSON.stringify(data.response)
|
||||||
|
processedChunk = Buffer.from(`data: ${newPayload}\n\n`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
// 解析失败,直接转发原始数据
|
||||||
|
// logger.warn('Failed to parse SSE chunk:', e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1️⃣ 立即转发处理后的数据
|
||||||
if (!res.destroyed) {
|
if (!res.destroyed) {
|
||||||
res.write(chunk) // 直接转发 Buffer,无需转换和序列化
|
res.write(processedChunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2️⃣ 异步提取 usage 数据(不阻塞转发)
|
// 2️⃣ 异步提取 usage 数据(不阻塞转发)
|
||||||
// 使用 setImmediate 将解析放到下一个事件循环
|
|
||||||
setImmediate(() => {
|
setImmediate(() => {
|
||||||
try {
|
try {
|
||||||
const chunkStr = chunk.toString()
|
const str = processedChunk.toString()
|
||||||
if (!chunkStr.trim()) {
|
if (!str.trim() || !str.includes('usageMetadata')) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 快速检查是否包含 usage 数据(避免不必要的解析)
|
// 简单的解析尝试
|
||||||
if (!chunkStr.includes('usageMetadata')) {
|
const match = str.match(/"usageMetadata":\s*({[^}]+})/)
|
||||||
return
|
if (match && match[1]) {
|
||||||
}
|
|
||||||
|
|
||||||
// 处理不完整的行
|
|
||||||
streamBuffer += chunkStr
|
|
||||||
const lines = streamBuffer.split('\n')
|
|
||||||
streamBuffer = lines.pop() || ''
|
|
||||||
|
|
||||||
// 仅解析包含 usage 的行
|
|
||||||
for (const line of lines) {
|
|
||||||
if (!line.trim() || !line.includes('usageMetadata')) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const parsed = parseSSELine(line)
|
const usage = JSON.parse(match[1])
|
||||||
if (parsed.type === 'data' && parsed.data.response?.usageMetadata) {
|
totalUsage = usage
|
||||||
totalUsage = parsed.data.response.usageMetadata
|
logger.debug('📊 Captured Gemini usage data:', totalUsage)
|
||||||
logger.debug('📊 Captured Gemini usage data:', totalUsage)
|
} catch (e) {
|
||||||
}
|
// ignore
|
||||||
} catch (parseError) {
|
|
||||||
// 解析失败但不影响转发
|
|
||||||
logger.warn('⚠️ Failed to parse usage line:', parseError.message)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// 提取失败但不影响转发
|
|
||||||
logger.warn('⚠️ Error extracting usage data:', error.message)
|
logger.warn('⚠️ Error extracting usage data:', error.message)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Error processing stream chunk:', error)
|
logger.error('Error processing stream chunk:', error)
|
||||||
// 不中断流,继续处理后续数据
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -682,6 +702,18 @@ async function handleStandardStreamGenerateContent(req, res) {
|
|||||||
stack: error.stack
|
stack: error.stack
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 处理速率限制
|
||||||
|
if (error.response?.status === 429) {
|
||||||
|
logger.warn(
|
||||||
|
`⚠️ Gemini account ${account.id} rate limited (Standard Stream API), marking as limited`
|
||||||
|
)
|
||||||
|
try {
|
||||||
|
await unifiedGeminiScheduler.markAccountRateLimited(account.id, 'gemini', sessionHash)
|
||||||
|
} catch (limitError) {
|
||||||
|
logger.warn('Failed to mark account as rate limited in scheduler:', limitError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!res.headersSent) {
|
if (!res.headersSent) {
|
||||||
const statusCode = normalizedError.status || 500
|
const statusCode = normalizedError.status || 500
|
||||||
const responseBody = {
|
const responseBody = {
|
||||||
|
|||||||
Reference in New Issue
Block a user