mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-23 00:53:33 +00:00
feat: gemini 流式响应
This commit is contained in:
@@ -415,9 +415,27 @@ async function handleGenerateContent(req, res) {
|
|||||||
const { model, project, user_prompt_id, request: requestData } = req.body;
|
const { model, project, user_prompt_id, request: requestData } = req.body;
|
||||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||||
console.log(321, requestData);
|
console.log(321, requestData);
|
||||||
|
|
||||||
|
// 处理 OpenAI 格式请求(没有 request 字段的情况)
|
||||||
|
let actualRequestData = requestData;
|
||||||
|
if (!requestData && req.body.messages) {
|
||||||
|
// 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象
|
||||||
|
actualRequestData = {
|
||||||
|
contents: req.body.messages.map(msg => ({
|
||||||
|
role: msg.role === 'assistant' ? 'model' : msg.role,
|
||||||
|
parts: [{ text: msg.content }]
|
||||||
|
})),
|
||||||
|
generationConfig: {
|
||||||
|
temperature: req.body.temperature,
|
||||||
|
maxOutputTokens: req.body.max_tokens,
|
||||||
|
topP: req.body.top_p,
|
||||||
|
topK: req.body.top_k
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// 验证必需参数
|
// 验证必需参数
|
||||||
if (!requestData || !requestData.contents) {
|
if (!actualRequestData || !actualRequestData.contents) {
|
||||||
return res.status(400).json({
|
return res.status(400).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Request contents are required',
|
message: 'Request contents are required',
|
||||||
@@ -442,7 +460,7 @@ async function handleGenerateContent(req, res) {
|
|||||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||||
const response = await geminiAccountService.generateContent(
|
const response = await geminiAccountService.generateContent(
|
||||||
client,
|
client,
|
||||||
{ model, request: requestData },
|
{ model, request: actualRequestData },
|
||||||
user_prompt_id,
|
user_prompt_id,
|
||||||
project || account.projectId,
|
project || account.projectId,
|
||||||
req.apiKey?.id // 使用 API Key ID 作为 session ID
|
req.apiKey?.id // 使用 API Key ID 作为 session ID
|
||||||
@@ -469,8 +487,26 @@ async function handleStreamGenerateContent(req, res) {
|
|||||||
const { model, project, user_prompt_id, request: requestData } = req.body;
|
const { model, project, user_prompt_id, request: requestData } = req.body;
|
||||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||||
|
|
||||||
|
// 处理 OpenAI 格式请求(没有 request 字段的情况)
|
||||||
|
let actualRequestData = requestData;
|
||||||
|
if (!requestData && req.body.messages) {
|
||||||
|
// 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象
|
||||||
|
actualRequestData = {
|
||||||
|
contents: req.body.messages.map(msg => ({
|
||||||
|
role: msg.role === 'assistant' ? 'model' : msg.role,
|
||||||
|
parts: [{ text: msg.content }]
|
||||||
|
})),
|
||||||
|
generationConfig: {
|
||||||
|
temperature: req.body.temperature,
|
||||||
|
maxOutputTokens: req.body.max_tokens,
|
||||||
|
topP: req.body.top_p,
|
||||||
|
topK: req.body.top_k
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// 验证必需参数
|
// 验证必需参数
|
||||||
if (!requestData || !requestData.contents) {
|
if (!actualRequestData || !actualRequestData.contents) {
|
||||||
return res.status(400).json({
|
return res.status(400).json({
|
||||||
error: {
|
error: {
|
||||||
message: 'Request contents are required',
|
message: 'Request contents are required',
|
||||||
@@ -506,7 +542,7 @@ async function handleStreamGenerateContent(req, res) {
|
|||||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||||
const streamResponse = await geminiAccountService.generateContentStream(
|
const streamResponse = await geminiAccountService.generateContentStream(
|
||||||
client,
|
client,
|
||||||
{ model, request: requestData },
|
{ model, request: actualRequestData },
|
||||||
user_prompt_id,
|
user_prompt_id,
|
||||||
project || account.projectId,
|
project || account.projectId,
|
||||||
req.apiKey?.id, // 使用 API Key ID 作为 session ID
|
req.apiKey?.id, // 使用 API Key ID 作为 session ID
|
||||||
|
|||||||
Reference in New Issue
Block a user