mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-23 09:38:02 +00:00
refactor: standardize code formatting and linting configuration
- Replace .eslintrc.js with .eslintrc.cjs for better ES module compatibility - Add .prettierrc configuration for consistent code formatting - Update package.json with new lint and format scripts - Add nodemon.json for development hot reloading configuration - Standardize code formatting across all JavaScript and Vue files - Update web admin SPA with improved linting rules and formatting - Add prettier configuration to web admin SPA 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
3846
src/routes/admin.js
3846
src/routes/admin.js
File diff suppressed because it is too large
Load Diff
@@ -1,360 +1,482 @@
|
||||
const express = require('express');
|
||||
const claudeRelayService = require('../services/claudeRelayService');
|
||||
const claudeConsoleRelayService = require('../services/claudeConsoleRelayService');
|
||||
const bedrockRelayService = require('../services/bedrockRelayService');
|
||||
const bedrockAccountService = require('../services/bedrockAccountService');
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const logger = require('../utils/logger');
|
||||
const redis = require('../models/redis');
|
||||
const sessionHelper = require('../utils/sessionHelper');
|
||||
const express = require('express')
|
||||
const claudeRelayService = require('../services/claudeRelayService')
|
||||
const claudeConsoleRelayService = require('../services/claudeConsoleRelayService')
|
||||
const bedrockRelayService = require('../services/bedrockRelayService')
|
||||
const bedrockAccountService = require('../services/bedrockAccountService')
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const logger = require('../utils/logger')
|
||||
const redis = require('../models/redis')
|
||||
const sessionHelper = require('../utils/sessionHelper')
|
||||
|
||||
const router = express.Router();
|
||||
const router = express.Router()
|
||||
|
||||
// 🔧 共享的消息处理函数
|
||||
async function handleMessagesRequest(req, res) {
|
||||
try {
|
||||
const startTime = Date.now();
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 严格的输入验证
|
||||
if (!req.body || typeof req.body !== 'object') {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid request',
|
||||
message: 'Request body must be a valid JSON object'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
if (!req.body.messages || !Array.isArray(req.body.messages)) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid request',
|
||||
message: 'Missing or invalid field: messages (must be an array)'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
if (req.body.messages.length === 0) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid request',
|
||||
message: 'Messages array cannot be empty'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 检查是否为流式请求
|
||||
const isStream = req.body.stream === true;
|
||||
|
||||
logger.api(`🚀 Processing ${isStream ? 'stream' : 'non-stream'} request for key: ${req.apiKey.name}`);
|
||||
const isStream = req.body.stream === true
|
||||
|
||||
logger.api(
|
||||
`🚀 Processing ${isStream ? 'stream' : 'non-stream'} request for key: ${req.apiKey.name}`
|
||||
)
|
||||
|
||||
if (isStream) {
|
||||
// 流式响应 - 只使用官方真实usage数据
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
||||
res.setHeader('X-Accel-Buffering', 'no'); // 禁用 Nginx 缓冲
|
||||
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('Access-Control-Allow-Origin', '*')
|
||||
res.setHeader('X-Accel-Buffering', 'no') // 禁用 Nginx 缓冲
|
||||
|
||||
// 禁用 Nagle 算法,确保数据立即发送
|
||||
if (res.socket && typeof res.socket.setNoDelay === 'function') {
|
||||
res.socket.setNoDelay(true);
|
||||
res.socket.setNoDelay(true)
|
||||
}
|
||||
|
||||
|
||||
// 流式响应不需要额外处理,中间件已经设置了监听器
|
||||
|
||||
let usageDataCaptured = false;
|
||||
|
||||
|
||||
let usageDataCaptured = false
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
|
||||
const requestedModel = req.body.model
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
|
||||
// 根据账号类型选择对应的转发服务并调用
|
||||
if (accountType === 'claude-official') {
|
||||
// 官方Claude账号使用原有的转发服务(会自己选择账号)
|
||||
await claudeRelayService.relayStreamRequestWithUsageCapture(req.body, req.apiKey, res, req.headers, (usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info('🎯 Usage callback triggered with complete data:', JSON.stringify(usageData, null, 2));
|
||||
|
||||
if (usageData && usageData.input_tokens !== undefined && usageData.output_tokens !== undefined) {
|
||||
const inputTokens = usageData.input_tokens || 0;
|
||||
const outputTokens = usageData.output_tokens || 0;
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0;
|
||||
const model = usageData.model || 'unknown';
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const accountId = usageData.accountId;
|
||||
apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model, accountId).catch(error => {
|
||||
logger.error('❌ Failed to record stream usage:', error);
|
||||
});
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens;
|
||||
redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens).catch(error => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error);
|
||||
});
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
await claudeRelayService.relayStreamRequestWithUsageCapture(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
res,
|
||||
req.headers,
|
||||
(usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info(
|
||||
'🎯 Usage callback triggered with complete data:',
|
||||
JSON.stringify(usageData, null, 2)
|
||||
)
|
||||
|
||||
if (
|
||||
usageData &&
|
||||
usageData.input_tokens !== undefined &&
|
||||
usageData.output_tokens !== undefined
|
||||
) {
|
||||
const inputTokens = usageData.input_tokens || 0
|
||||
const outputTokens = usageData.output_tokens || 0
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0
|
||||
const model = usageData.model || 'unknown'
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const { accountId: usageAccountId } = usageData
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
req.apiKey.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
usageAccountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record stream usage:', error)
|
||||
})
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
|
||||
redis
|
||||
.getClient()
|
||||
.incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error)
|
||||
})
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageDataCaptured = true
|
||||
logger.api(
|
||||
`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`
|
||||
)
|
||||
} else {
|
||||
logger.warn(
|
||||
'⚠️ Usage callback triggered but data is incomplete:',
|
||||
JSON.stringify(usageData)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
usageDataCaptured = true;
|
||||
logger.api(`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`);
|
||||
} else {
|
||||
logger.warn('⚠️ Usage callback triggered but data is incomplete:', JSON.stringify(usageData));
|
||||
}
|
||||
});
|
||||
)
|
||||
} else if (accountType === 'claude-console') {
|
||||
// Claude Console账号使用Console转发服务(需要传递accountId)
|
||||
await claudeConsoleRelayService.relayStreamRequestWithUsageCapture(req.body, req.apiKey, res, req.headers, (usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info('🎯 Usage callback triggered with complete data:', JSON.stringify(usageData, null, 2));
|
||||
|
||||
if (usageData && usageData.input_tokens !== undefined && usageData.output_tokens !== undefined) {
|
||||
const inputTokens = usageData.input_tokens || 0;
|
||||
const outputTokens = usageData.output_tokens || 0;
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0;
|
||||
const model = usageData.model || 'unknown';
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const usageAccountId = usageData.accountId;
|
||||
apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model, usageAccountId).catch(error => {
|
||||
logger.error('❌ Failed to record stream usage:', error);
|
||||
});
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens;
|
||||
redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens).catch(error => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error);
|
||||
});
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
await claudeConsoleRelayService.relayStreamRequestWithUsageCapture(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
res,
|
||||
req.headers,
|
||||
(usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info(
|
||||
'🎯 Usage callback triggered with complete data:',
|
||||
JSON.stringify(usageData, null, 2)
|
||||
)
|
||||
|
||||
if (
|
||||
usageData &&
|
||||
usageData.input_tokens !== undefined &&
|
||||
usageData.output_tokens !== undefined
|
||||
) {
|
||||
const inputTokens = usageData.input_tokens || 0
|
||||
const outputTokens = usageData.output_tokens || 0
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0
|
||||
const model = usageData.model || 'unknown'
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const usageAccountId = usageData.accountId
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
req.apiKey.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
usageAccountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record stream usage:', error)
|
||||
})
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
|
||||
redis
|
||||
.getClient()
|
||||
.incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error)
|
||||
})
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageDataCaptured = true
|
||||
logger.api(
|
||||
`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`
|
||||
)
|
||||
} else {
|
||||
logger.warn(
|
||||
'⚠️ Usage callback triggered but data is incomplete:',
|
||||
JSON.stringify(usageData)
|
||||
)
|
||||
}
|
||||
|
||||
usageDataCaptured = true;
|
||||
logger.api(`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`);
|
||||
} else {
|
||||
logger.warn('⚠️ Usage callback triggered but data is incomplete:', JSON.stringify(usageData));
|
||||
}
|
||||
}, accountId);
|
||||
},
|
||||
accountId
|
||||
)
|
||||
} else if (accountType === 'bedrock') {
|
||||
// Bedrock账号使用Bedrock转发服务
|
||||
try {
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId);
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId)
|
||||
if (!bedrockAccountResult.success) {
|
||||
throw new Error('Failed to get Bedrock account details');
|
||||
throw new Error('Failed to get Bedrock account details')
|
||||
}
|
||||
|
||||
const result = await bedrockRelayService.handleStreamRequest(req.body, bedrockAccountResult.data, res);
|
||||
|
||||
const result = await bedrockRelayService.handleStreamRequest(
|
||||
req.body,
|
||||
bedrockAccountResult.data,
|
||||
res
|
||||
)
|
||||
|
||||
// 记录Bedrock使用统计
|
||||
if (result.usage) {
|
||||
const inputTokens = result.usage.input_tokens || 0;
|
||||
const outputTokens = result.usage.output_tokens || 0;
|
||||
|
||||
apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, 0, 0, result.model, accountId).catch(error => {
|
||||
logger.error('❌ Failed to record Bedrock stream usage:', error);
|
||||
});
|
||||
|
||||
const inputTokens = result.usage.input_tokens || 0
|
||||
const outputTokens = result.usage.output_tokens || 0
|
||||
|
||||
apiKeyService
|
||||
.recordUsage(req.apiKey.id, inputTokens, outputTokens, 0, 0, result.model, accountId)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record Bedrock stream usage:', error)
|
||||
})
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens;
|
||||
redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens).catch(error => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error);
|
||||
});
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
const totalTokens = inputTokens + outputTokens
|
||||
redis
|
||||
.getClient()
|
||||
.incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error)
|
||||
})
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageDataCaptured = true;
|
||||
logger.api(`📊 Bedrock stream usage recorded - Model: ${result.model}, Input: ${inputTokens}, Output: ${outputTokens}, Total: ${inputTokens + outputTokens} tokens`);
|
||||
|
||||
usageDataCaptured = true
|
||||
logger.api(
|
||||
`📊 Bedrock stream usage recorded - Model: ${result.model}, Input: ${inputTokens}, Output: ${outputTokens}, Total: ${inputTokens + outputTokens} tokens`
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Bedrock stream request failed:', error);
|
||||
logger.error('❌ Bedrock stream request failed:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({ error: 'Bedrock service error', message: error.message });
|
||||
return res.status(500).json({ error: 'Bedrock service error', message: error.message })
|
||||
}
|
||||
return;
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 流式请求完成后 - 如果没有捕获到usage数据,记录警告但不进行估算
|
||||
setTimeout(() => {
|
||||
if (!usageDataCaptured) {
|
||||
logger.warn('⚠️ No usage data captured from SSE stream - no statistics recorded (official data only)');
|
||||
logger.warn(
|
||||
'⚠️ No usage data captured from SSE stream - no statistics recorded (official data only)'
|
||||
)
|
||||
}
|
||||
}, 1000); // 1秒后检查
|
||||
}, 1000) // 1秒后检查
|
||||
} else {
|
||||
// 非流式响应 - 只使用官方真实usage数据
|
||||
logger.info('📄 Starting non-streaming request', {
|
||||
apiKeyId: req.apiKey.id,
|
||||
apiKeyName: req.apiKey.name
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
|
||||
const requestedModel = req.body.model
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
|
||||
// 根据账号类型选择对应的转发服务
|
||||
let response;
|
||||
logger.debug(`[DEBUG] Request query params: ${JSON.stringify(req.query)}`);
|
||||
logger.debug(`[DEBUG] Request URL: ${req.url}`);
|
||||
logger.debug(`[DEBUG] Request path: ${req.path}`);
|
||||
|
||||
let response
|
||||
logger.debug(`[DEBUG] Request query params: ${JSON.stringify(req.query)}`)
|
||||
logger.debug(`[DEBUG] Request URL: ${req.url}`)
|
||||
logger.debug(`[DEBUG] Request path: ${req.path}`)
|
||||
|
||||
if (accountType === 'claude-official') {
|
||||
// 官方Claude账号使用原有的转发服务
|
||||
response = await claudeRelayService.relayRequest(req.body, req.apiKey, req, res, req.headers);
|
||||
response = await claudeRelayService.relayRequest(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
req,
|
||||
res,
|
||||
req.headers
|
||||
)
|
||||
} else if (accountType === 'claude-console') {
|
||||
// Claude Console账号使用Console转发服务
|
||||
logger.debug(`[DEBUG] Calling claudeConsoleRelayService.relayRequest with accountId: ${accountId}`);
|
||||
response = await claudeConsoleRelayService.relayRequest(req.body, req.apiKey, req, res, req.headers, accountId);
|
||||
logger.debug(
|
||||
`[DEBUG] Calling claudeConsoleRelayService.relayRequest with accountId: ${accountId}`
|
||||
)
|
||||
response = await claudeConsoleRelayService.relayRequest(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
req,
|
||||
res,
|
||||
req.headers,
|
||||
accountId
|
||||
)
|
||||
} else if (accountType === 'bedrock') {
|
||||
// Bedrock账号使用Bedrock转发服务
|
||||
try {
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId);
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId)
|
||||
if (!bedrockAccountResult.success) {
|
||||
throw new Error('Failed to get Bedrock account details');
|
||||
throw new Error('Failed to get Bedrock account details')
|
||||
}
|
||||
|
||||
const result = await bedrockRelayService.handleNonStreamRequest(req.body, bedrockAccountResult.data, req.headers);
|
||||
|
||||
const result = await bedrockRelayService.handleNonStreamRequest(
|
||||
req.body,
|
||||
bedrockAccountResult.data,
|
||||
req.headers
|
||||
)
|
||||
|
||||
// 构建标准响应格式
|
||||
response = {
|
||||
statusCode: result.success ? 200 : 500,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(result.success ? result.data : { error: result.error }),
|
||||
accountId: accountId
|
||||
};
|
||||
|
||||
accountId
|
||||
}
|
||||
|
||||
// 如果成功,添加使用统计到响应数据中
|
||||
if (result.success && result.usage) {
|
||||
const responseData = JSON.parse(response.body);
|
||||
responseData.usage = result.usage;
|
||||
response.body = JSON.stringify(responseData);
|
||||
const responseData = JSON.parse(response.body)
|
||||
responseData.usage = result.usage
|
||||
response.body = JSON.stringify(responseData)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Bedrock non-stream request failed:', error);
|
||||
logger.error('❌ Bedrock non-stream request failed:', error)
|
||||
response = {
|
||||
statusCode: 500,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ error: 'Bedrock service error', message: error.message }),
|
||||
accountId: accountId
|
||||
};
|
||||
accountId
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
logger.info('📡 Claude API response received', {
|
||||
statusCode: response.statusCode,
|
||||
headers: JSON.stringify(response.headers),
|
||||
bodyLength: response.body ? response.body.length : 0
|
||||
});
|
||||
|
||||
res.status(response.statusCode);
|
||||
|
||||
})
|
||||
|
||||
res.status(response.statusCode)
|
||||
|
||||
// 设置响应头,避免 Content-Length 和 Transfer-Encoding 冲突
|
||||
const skipHeaders = ['content-encoding', 'transfer-encoding', 'content-length'];
|
||||
Object.keys(response.headers).forEach(key => {
|
||||
const skipHeaders = ['content-encoding', 'transfer-encoding', 'content-length']
|
||||
Object.keys(response.headers).forEach((key) => {
|
||||
if (!skipHeaders.includes(key.toLowerCase())) {
|
||||
res.setHeader(key, response.headers[key]);
|
||||
res.setHeader(key, response.headers[key])
|
||||
}
|
||||
});
|
||||
|
||||
let usageRecorded = false;
|
||||
|
||||
})
|
||||
|
||||
let usageRecorded = false
|
||||
|
||||
// 尝试解析JSON响应并提取usage信息
|
||||
try {
|
||||
const jsonData = JSON.parse(response.body);
|
||||
|
||||
logger.info('📊 Parsed Claude API response:', JSON.stringify(jsonData, null, 2));
|
||||
|
||||
const jsonData = JSON.parse(response.body)
|
||||
|
||||
logger.info('📊 Parsed Claude API response:', JSON.stringify(jsonData, null, 2))
|
||||
|
||||
// 从Claude API响应中提取usage信息(完整的token分类体系)
|
||||
if (jsonData.usage && jsonData.usage.input_tokens !== undefined && jsonData.usage.output_tokens !== undefined) {
|
||||
const inputTokens = jsonData.usage.input_tokens || 0;
|
||||
const outputTokens = jsonData.usage.output_tokens || 0;
|
||||
const cacheCreateTokens = jsonData.usage.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = jsonData.usage.cache_read_input_tokens || 0;
|
||||
const model = jsonData.model || req.body.model || 'unknown';
|
||||
|
||||
if (
|
||||
jsonData.usage &&
|
||||
jsonData.usage.input_tokens !== undefined &&
|
||||
jsonData.usage.output_tokens !== undefined
|
||||
) {
|
||||
const inputTokens = jsonData.usage.input_tokens || 0
|
||||
const outputTokens = jsonData.usage.output_tokens || 0
|
||||
const cacheCreateTokens = jsonData.usage.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = jsonData.usage.cache_read_input_tokens || 0
|
||||
const model = jsonData.model || req.body.model || 'unknown'
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const accountId = response.accountId;
|
||||
await apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model, accountId);
|
||||
|
||||
const { accountId: responseAccountId } = response
|
||||
await apiKeyService.recordUsage(
|
||||
req.apiKey.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
responseAccountId
|
||||
)
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens;
|
||||
await redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens);
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
|
||||
await redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageRecorded = true;
|
||||
logger.api(`📊 Non-stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`);
|
||||
|
||||
usageRecorded = true
|
||||
logger.api(
|
||||
`📊 Non-stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`
|
||||
)
|
||||
} else {
|
||||
logger.warn('⚠️ No usage data found in Claude API JSON response');
|
||||
logger.warn('⚠️ No usage data found in Claude API JSON response')
|
||||
}
|
||||
|
||||
res.json(jsonData);
|
||||
|
||||
res.json(jsonData)
|
||||
} catch (parseError) {
|
||||
logger.warn('⚠️ Failed to parse Claude API response as JSON:', parseError.message);
|
||||
logger.info('📄 Raw response body:', response.body);
|
||||
res.send(response.body);
|
||||
logger.warn('⚠️ Failed to parse Claude API response as JSON:', parseError.message)
|
||||
logger.info('📄 Raw response body:', response.body)
|
||||
res.send(response.body)
|
||||
}
|
||||
|
||||
|
||||
// 如果没有记录usage,只记录警告,不进行估算
|
||||
if (!usageRecorded) {
|
||||
logger.warn('⚠️ No usage data recorded for non-stream request - no statistics recorded (official data only)');
|
||||
logger.warn(
|
||||
'⚠️ No usage data recorded for non-stream request - no statistics recorded (official data only)'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.api(`✅ Request completed in ${duration}ms for key: ${req.apiKey.name}`);
|
||||
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.api(`✅ Request completed in ${duration}ms for key: ${req.apiKey.name}`)
|
||||
return undefined
|
||||
} catch (error) {
|
||||
logger.error('❌ Claude relay error:', error.message, {
|
||||
code: error.code,
|
||||
stack: error.stack
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 确保在任何情况下都能返回有效的JSON响应
|
||||
if (!res.headersSent) {
|
||||
// 根据错误类型设置适当的状态码
|
||||
let statusCode = 500;
|
||||
let errorType = 'Relay service error';
|
||||
|
||||
let statusCode = 500
|
||||
let errorType = 'Relay service error'
|
||||
|
||||
if (error.message.includes('Connection reset') || error.message.includes('socket hang up')) {
|
||||
statusCode = 502;
|
||||
errorType = 'Upstream connection error';
|
||||
statusCode = 502
|
||||
errorType = 'Upstream connection error'
|
||||
} else if (error.message.includes('Connection refused')) {
|
||||
statusCode = 502;
|
||||
errorType = 'Upstream service unavailable';
|
||||
statusCode = 502
|
||||
errorType = 'Upstream service unavailable'
|
||||
} else if (error.message.includes('timeout')) {
|
||||
statusCode = 504;
|
||||
errorType = 'Upstream timeout';
|
||||
statusCode = 504
|
||||
errorType = 'Upstream timeout'
|
||||
} else if (error.message.includes('resolve') || error.message.includes('ENOTFOUND')) {
|
||||
statusCode = 502;
|
||||
errorType = 'Upstream hostname resolution failed';
|
||||
statusCode = 502
|
||||
errorType = 'Upstream hostname resolution failed'
|
||||
}
|
||||
|
||||
res.status(statusCode).json({
|
||||
|
||||
return res.status(statusCode).json({
|
||||
error: errorType,
|
||||
message: error.message || 'An unexpected error occurred',
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} else {
|
||||
// 如果响应头已经发送,尝试结束响应
|
||||
if (!res.destroyed && !res.finished) {
|
||||
res.end();
|
||||
res.end()
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 🚀 Claude API messages 端点 - /api/v1/messages
|
||||
router.post('/v1/messages', authenticateApiKey, handleMessagesRequest);
|
||||
router.post('/v1/messages', authenticateApiKey, handleMessagesRequest)
|
||||
|
||||
// 🚀 Claude API messages 端点 - /claude/v1/messages (别名)
|
||||
router.post('/claude/v1/messages', authenticateApiKey, handleMessagesRequest);
|
||||
router.post('/claude/v1/messages', authenticateApiKey, handleMessagesRequest)
|
||||
|
||||
// 📋 模型列表端点 - Claude Code 客户端需要
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
@@ -368,66 +490,65 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
owned_by: 'anthropic'
|
||||
},
|
||||
{
|
||||
id: 'claude-3-5-haiku-20241022',
|
||||
id: 'claude-3-5-haiku-20241022',
|
||||
object: 'model',
|
||||
created: 1669599635,
|
||||
owned_by: 'anthropic'
|
||||
},
|
||||
{
|
||||
id: 'claude-3-opus-20240229',
|
||||
object: 'model',
|
||||
object: 'model',
|
||||
created: 1669599635,
|
||||
owned_by: 'anthropic'
|
||||
},
|
||||
{
|
||||
id: 'claude-sonnet-4-20250514',
|
||||
object: 'model',
|
||||
created: 1669599635,
|
||||
created: 1669599635,
|
||||
owned_by: 'anthropic'
|
||||
}
|
||||
];
|
||||
|
||||
]
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Models list error:', error);
|
||||
logger.error('❌ Models list error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get models list',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🏥 健康检查端点
|
||||
router.get('/health', async (req, res) => {
|
||||
try {
|
||||
const healthStatus = await claudeRelayService.healthCheck();
|
||||
|
||||
const healthStatus = await claudeRelayService.healthCheck()
|
||||
|
||||
res.status(healthStatus.healthy ? 200 : 503).json({
|
||||
status: healthStatus.healthy ? 'healthy' : 'unhealthy',
|
||||
service: 'claude-relay-service',
|
||||
version: '1.0.0',
|
||||
...healthStatus
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Health check error:', error);
|
||||
logger.error('❌ Health check error:', error)
|
||||
res.status(503).json({
|
||||
status: 'unhealthy',
|
||||
service: 'claude-relay-service',
|
||||
error: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📊 API Key状态检查端点 - /api/v1/key-info
|
||||
router.get('/v1/key-info', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id);
|
||||
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id)
|
||||
|
||||
res.json({
|
||||
keyInfo: {
|
||||
id: req.apiKey.id,
|
||||
@@ -436,21 +557,21 @@ router.get('/v1/key-info', authenticateApiKey, async (req, res) => {
|
||||
usage
|
||||
},
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Key info error:', error);
|
||||
logger.error('❌ Key info error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get key info',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📈 使用统计端点 - /api/v1/usage
|
||||
router.get('/v1/usage', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id);
|
||||
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id)
|
||||
|
||||
res.json({
|
||||
usage,
|
||||
limits: {
|
||||
@@ -458,56 +579,56 @@ router.get('/v1/usage', authenticateApiKey, async (req, res) => {
|
||||
requests: 0 // 请求限制已移除
|
||||
},
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Usage stats error:', error);
|
||||
logger.error('❌ Usage stats error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get usage stats',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 👤 用户信息端点 - Claude Code 客户端需要
|
||||
router.get('/v1/me', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
// 返回基础用户信息
|
||||
res.json({
|
||||
id: 'user_' + req.apiKey.id,
|
||||
type: 'user',
|
||||
id: `user_${req.apiKey.id}`,
|
||||
type: 'user',
|
||||
display_name: req.apiKey.name || 'API User',
|
||||
created_at: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ User info error:', error);
|
||||
logger.error('❌ User info error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get user info',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 💰 余额/限制端点 - Claude Code 客户端需要
|
||||
router.get('/v1/organizations/:org_id/usage', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id);
|
||||
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id)
|
||||
|
||||
res.json({
|
||||
object: 'usage',
|
||||
data: [
|
||||
{
|
||||
type: 'credit_balance',
|
||||
type: 'credit_balance',
|
||||
credit_balance: req.apiKey.tokenLimit - (usage.totalTokens || 0)
|
||||
}
|
||||
]
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Organization usage error:', error);
|
||||
logger.error('❌ Organization usage error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get usage info',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
const express = require('express');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const CostCalculator = require('../utils/costCalculator');
|
||||
const express = require('express')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
const CostCalculator = require('../utils/costCalculator')
|
||||
|
||||
const router = express.Router();
|
||||
const router = express.Router()
|
||||
|
||||
// 🏠 重定向页面请求到新版 admin-spa
|
||||
router.get('/', (req, res) => {
|
||||
res.redirect(301, '/admin-next/api-stats');
|
||||
});
|
||||
res.redirect(301, '/admin-next/api-stats')
|
||||
})
|
||||
|
||||
// 🔑 获取 API Key 对应的 ID
|
||||
router.post('/api/get-key-id', async (req, res) => {
|
||||
try {
|
||||
const { apiKey } = req.body;
|
||||
|
||||
const { apiKey } = req.body
|
||||
|
||||
if (!apiKey) {
|
||||
return res.status(400).json({
|
||||
error: 'API Key is required',
|
||||
message: 'Please provide your API Key'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 基本API Key格式验证
|
||||
@@ -28,108 +28,110 @@ router.post('/api/get-key-id', async (req, res) => {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API key format',
|
||||
message: 'API key format is invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证API Key
|
||||
const validation = await apiKeyService.validateApiKey(apiKey);
|
||||
|
||||
const validation = await apiKeyService.validateApiKey(apiKey)
|
||||
|
||||
if (!validation.valid) {
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown';
|
||||
logger.security(`🔒 Invalid API key in get-key-id: ${validation.error} from ${clientIP}`);
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown'
|
||||
logger.security(`🔒 Invalid API key in get-key-id: ${validation.error} from ${clientIP}`)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid API key',
|
||||
message: validation.error
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const keyData = validation.keyData;
|
||||
|
||||
res.json({
|
||||
const { keyData } = validation
|
||||
|
||||
return res.json({
|
||||
success: true,
|
||||
data: {
|
||||
id: keyData.id
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get API key ID:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Failed to get API key ID:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: 'Failed to retrieve API key ID'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📊 用户API Key统计查询接口 - 安全的自查询接口
|
||||
router.post('/api/user-stats', async (req, res) => {
|
||||
try {
|
||||
const { apiKey, apiId } = req.body;
|
||||
|
||||
let keyData;
|
||||
let keyId;
|
||||
|
||||
const { apiKey, apiId } = req.body
|
||||
|
||||
let keyData
|
||||
let keyId
|
||||
|
||||
if (apiId) {
|
||||
// 通过 apiId 查询
|
||||
if (typeof apiId !== 'string' || !apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)) {
|
||||
if (
|
||||
typeof apiId !== 'string' ||
|
||||
!apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)
|
||||
) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API ID format',
|
||||
message: 'API ID must be a valid UUID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 直接通过 ID 获取 API Key 数据
|
||||
keyData = await redis.getApiKey(apiId);
|
||||
|
||||
keyData = await redis.getApiKey(apiId)
|
||||
|
||||
if (!keyData || Object.keys(keyData).length === 0) {
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`)
|
||||
return res.status(404).json({
|
||||
error: 'API key not found',
|
||||
message: 'The specified API key does not exist'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查是否激活
|
||||
if (keyData.isActive !== 'true') {
|
||||
return res.status(403).json({
|
||||
error: 'API key is disabled',
|
||||
message: 'This API key has been disabled'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查是否过期
|
||||
if (keyData.expiresAt && new Date() > new Date(keyData.expiresAt)) {
|
||||
return res.status(403).json({
|
||||
error: 'API key has expired',
|
||||
message: 'This API key has expired'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyId = apiId;
|
||||
|
||||
|
||||
keyId = apiId
|
||||
|
||||
// 获取使用统计
|
||||
const usage = await redis.getUsageStats(keyId);
|
||||
|
||||
const usage = await redis.getUsageStats(keyId)
|
||||
|
||||
// 获取当日费用统计
|
||||
const dailyCost = await redis.getDailyCost(keyId);
|
||||
|
||||
const dailyCost = await redis.getDailyCost(keyId)
|
||||
|
||||
// 处理数据格式,与 validateApiKey 返回的格式保持一致
|
||||
// 解析限制模型数据
|
||||
let restrictedModels = [];
|
||||
let restrictedModels = []
|
||||
try {
|
||||
restrictedModels = keyData.restrictedModels ? JSON.parse(keyData.restrictedModels) : [];
|
||||
restrictedModels = keyData.restrictedModels ? JSON.parse(keyData.restrictedModels) : []
|
||||
} catch (e) {
|
||||
restrictedModels = [];
|
||||
restrictedModels = []
|
||||
}
|
||||
|
||||
|
||||
// 解析允许的客户端数据
|
||||
let allowedClients = [];
|
||||
let allowedClients = []
|
||||
try {
|
||||
allowedClients = keyData.allowedClients ? JSON.parse(keyData.allowedClients) : [];
|
||||
allowedClients = keyData.allowedClients ? JSON.parse(keyData.allowedClients) : []
|
||||
} catch (e) {
|
||||
allowedClients = [];
|
||||
allowedClients = []
|
||||
}
|
||||
|
||||
|
||||
// 格式化 keyData
|
||||
keyData = {
|
||||
...keyData,
|
||||
@@ -140,70 +142,75 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
dailyCostLimit: parseFloat(keyData.dailyCostLimit) || 0,
|
||||
dailyCost: dailyCost || 0,
|
||||
enableModelRestriction: keyData.enableModelRestriction === 'true',
|
||||
restrictedModels: restrictedModels,
|
||||
restrictedModels,
|
||||
enableClientRestriction: keyData.enableClientRestriction === 'true',
|
||||
allowedClients: allowedClients,
|
||||
allowedClients,
|
||||
permissions: keyData.permissions || 'all',
|
||||
usage: usage // 使用完整的 usage 数据,而不是只有 total
|
||||
};
|
||||
|
||||
usage // 使用完整的 usage 数据,而不是只有 total
|
||||
}
|
||||
} else if (apiKey) {
|
||||
// 通过 apiKey 查询(保持向后兼容)
|
||||
if (typeof apiKey !== 'string' || apiKey.length < 10 || apiKey.length > 512) {
|
||||
logger.security(`🔒 Invalid API key format in user stats query from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 Invalid API key format in user stats query from ${req.ip || 'unknown'}`)
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API key format',
|
||||
message: 'API key format is invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证API Key(重用现有的验证逻辑)
|
||||
const validation = await apiKeyService.validateApiKey(apiKey);
|
||||
|
||||
const validation = await apiKeyService.validateApiKey(apiKey)
|
||||
|
||||
if (!validation.valid) {
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown';
|
||||
logger.security(`🔒 Invalid API key in user stats query: ${validation.error} from ${clientIP}`);
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown'
|
||||
logger.security(
|
||||
`🔒 Invalid API key in user stats query: ${validation.error} from ${clientIP}`
|
||||
)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid API key',
|
||||
message: validation.error
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyData = validation.keyData;
|
||||
keyId = keyData.id;
|
||||
|
||||
const { keyData: validatedKeyData } = validation
|
||||
keyData = validatedKeyData
|
||||
keyId = keyData.id
|
||||
} else {
|
||||
logger.security(`🔒 Missing API key or ID in user stats query from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 Missing API key or ID in user stats query from ${req.ip || 'unknown'}`)
|
||||
return res.status(400).json({
|
||||
error: 'API Key or ID is required',
|
||||
message: 'Please provide your API Key or API ID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 记录合法查询
|
||||
logger.api(`📊 User stats query from key: ${keyData.name} (${keyId}) from ${req.ip || 'unknown'}`);
|
||||
logger.api(
|
||||
`📊 User stats query from key: ${keyData.name} (${keyId}) from ${req.ip || 'unknown'}`
|
||||
)
|
||||
|
||||
// 获取验证结果中的完整keyData(包含isActive状态和cost信息)
|
||||
const fullKeyData = keyData;
|
||||
|
||||
const fullKeyData = keyData
|
||||
|
||||
// 计算总费用 - 使用与模型统计相同的逻辑(按模型分别计算)
|
||||
let totalCost = 0;
|
||||
let formattedCost = '$0.000000';
|
||||
|
||||
let totalCost = 0
|
||||
let formattedCost = '$0.000000'
|
||||
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 获取所有月度模型统计(与model-stats接口相同的逻辑)
|
||||
const allModelKeys = await client.keys(`usage:${keyId}:model:monthly:*:*`);
|
||||
const modelUsageMap = new Map();
|
||||
|
||||
const allModelKeys = await client.keys(`usage:${keyId}:model:monthly:*:*`)
|
||||
const modelUsageMap = new Map()
|
||||
|
||||
for (const key of allModelKeys) {
|
||||
const modelMatch = key.match(/usage:.+:model:monthly:(.+):(\d{4}-\d{2})$/);
|
||||
if (!modelMatch) continue;
|
||||
|
||||
const model = modelMatch[1];
|
||||
const data = await client.hgetall(key);
|
||||
|
||||
const modelMatch = key.match(/usage:.+:model:monthly:(.+):(\d{4}-\d{2})$/)
|
||||
if (!modelMatch) {
|
||||
continue
|
||||
}
|
||||
|
||||
const model = modelMatch[1]
|
||||
const data = await client.hgetall(key)
|
||||
|
||||
if (data && Object.keys(data).length > 0) {
|
||||
if (!modelUsageMap.has(model)) {
|
||||
modelUsageMap.set(model, {
|
||||
@@ -211,17 +218,17 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
outputTokens: 0,
|
||||
cacheCreateTokens: 0,
|
||||
cacheReadTokens: 0
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const modelUsage = modelUsageMap.get(model);
|
||||
modelUsage.inputTokens += parseInt(data.inputTokens) || 0;
|
||||
modelUsage.outputTokens += parseInt(data.outputTokens) || 0;
|
||||
modelUsage.cacheCreateTokens += parseInt(data.cacheCreateTokens) || 0;
|
||||
modelUsage.cacheReadTokens += parseInt(data.cacheReadTokens) || 0;
|
||||
|
||||
const modelUsage = modelUsageMap.get(model)
|
||||
modelUsage.inputTokens += parseInt(data.inputTokens) || 0
|
||||
modelUsage.outputTokens += parseInt(data.outputTokens) || 0
|
||||
modelUsage.cacheCreateTokens += parseInt(data.cacheCreateTokens) || 0
|
||||
modelUsage.cacheReadTokens += parseInt(data.cacheReadTokens) || 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 按模型计算费用并汇总
|
||||
for (const [model, usage] of modelUsageMap) {
|
||||
const usageData = {
|
||||
@@ -229,66 +236,65 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
output_tokens: usage.outputTokens,
|
||||
cache_creation_input_tokens: usage.cacheCreateTokens,
|
||||
cache_read_input_tokens: usage.cacheReadTokens
|
||||
};
|
||||
|
||||
const costResult = CostCalculator.calculateCost(usageData, model);
|
||||
totalCost += costResult.costs.total;
|
||||
}
|
||||
|
||||
const costResult = CostCalculator.calculateCost(usageData, model)
|
||||
totalCost += costResult.costs.total
|
||||
}
|
||||
|
||||
|
||||
// 如果没有模型级别的详细数据,回退到总体数据计算
|
||||
if (modelUsageMap.size === 0 && fullKeyData.usage?.total?.allTokens > 0) {
|
||||
const usage = fullKeyData.usage.total;
|
||||
const usage = fullKeyData.usage.total
|
||||
const costUsage = {
|
||||
input_tokens: usage.inputTokens || 0,
|
||||
output_tokens: usage.outputTokens || 0,
|
||||
cache_creation_input_tokens: usage.cacheCreateTokens || 0,
|
||||
cache_read_input_tokens: usage.cacheReadTokens || 0
|
||||
};
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022');
|
||||
totalCost = costResult.costs.total;
|
||||
}
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022')
|
||||
totalCost = costResult.costs.total
|
||||
}
|
||||
|
||||
formattedCost = CostCalculator.formatCost(totalCost);
|
||||
|
||||
|
||||
formattedCost = CostCalculator.formatCost(totalCost)
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to calculate detailed cost for key ${keyId}:`, error);
|
||||
logger.warn(`Failed to calculate detailed cost for key ${keyId}:`, error)
|
||||
// 回退到简单计算
|
||||
if (fullKeyData.usage?.total?.allTokens > 0) {
|
||||
const usage = fullKeyData.usage.total;
|
||||
const usage = fullKeyData.usage.total
|
||||
const costUsage = {
|
||||
input_tokens: usage.inputTokens || 0,
|
||||
output_tokens: usage.outputTokens || 0,
|
||||
cache_creation_input_tokens: usage.cacheCreateTokens || 0,
|
||||
cache_read_input_tokens: usage.cacheReadTokens || 0
|
||||
};
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022');
|
||||
totalCost = costResult.costs.total;
|
||||
formattedCost = costResult.formatted.total;
|
||||
}
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022')
|
||||
totalCost = costResult.costs.total
|
||||
formattedCost = costResult.formatted.total
|
||||
}
|
||||
}
|
||||
|
||||
// 获取当前使用量
|
||||
let currentWindowRequests = 0;
|
||||
let currentWindowTokens = 0;
|
||||
let currentDailyCost = 0;
|
||||
|
||||
let currentWindowRequests = 0
|
||||
let currentWindowTokens = 0
|
||||
let currentDailyCost = 0
|
||||
|
||||
try {
|
||||
// 获取当前时间窗口的请求次数和Token使用量
|
||||
if (fullKeyData.rateLimitWindow > 0) {
|
||||
const client = redis.getClientSafe();
|
||||
const requestCountKey = `rate_limit:requests:${keyId}`;
|
||||
const tokenCountKey = `rate_limit:tokens:${keyId}`;
|
||||
|
||||
currentWindowRequests = parseInt(await client.get(requestCountKey) || '0');
|
||||
currentWindowTokens = parseInt(await client.get(tokenCountKey) || '0');
|
||||
const client = redis.getClientSafe()
|
||||
const requestCountKey = `rate_limit:requests:${keyId}`
|
||||
const tokenCountKey = `rate_limit:tokens:${keyId}`
|
||||
|
||||
currentWindowRequests = parseInt((await client.get(requestCountKey)) || '0')
|
||||
currentWindowTokens = parseInt((await client.get(tokenCountKey)) || '0')
|
||||
}
|
||||
|
||||
|
||||
// 获取当日费用
|
||||
currentDailyCost = await redis.getDailyCost(keyId) || 0;
|
||||
currentDailyCost = (await redis.getDailyCost(keyId)) || 0
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to get current usage for key ${keyId}:`, error);
|
||||
logger.warn(`Failed to get current usage for key ${keyId}:`, error)
|
||||
}
|
||||
|
||||
// 构建响应数据(只返回该API Key自己的信息,确保不泄露其他信息)
|
||||
@@ -300,7 +306,7 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
createdAt: keyData.createdAt,
|
||||
expiresAt: keyData.expiresAt,
|
||||
permissions: fullKeyData.permissions,
|
||||
|
||||
|
||||
// 使用统计(使用验证结果中的完整数据)
|
||||
usage: {
|
||||
total: {
|
||||
@@ -314,10 +320,10 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
cacheReadTokens: 0
|
||||
}),
|
||||
cost: totalCost,
|
||||
formattedCost: formattedCost
|
||||
formattedCost
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
// 限制信息(显示配置和当前使用量)
|
||||
limits: {
|
||||
tokenLimit: fullKeyData.tokenLimit || 0,
|
||||
@@ -326,17 +332,23 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
rateLimitRequests: fullKeyData.rateLimitRequests || 0,
|
||||
dailyCostLimit: fullKeyData.dailyCostLimit || 0,
|
||||
// 当前使用量
|
||||
currentWindowRequests: currentWindowRequests,
|
||||
currentWindowTokens: currentWindowTokens,
|
||||
currentDailyCost: currentDailyCost
|
||||
currentWindowRequests,
|
||||
currentWindowTokens,
|
||||
currentDailyCost
|
||||
},
|
||||
|
||||
|
||||
// 绑定的账户信息(只显示ID,不显示敏感信息)
|
||||
accounts: {
|
||||
claudeAccountId: fullKeyData.claudeAccountId && fullKeyData.claudeAccountId !== '' ? fullKeyData.claudeAccountId : null,
|
||||
geminiAccountId: fullKeyData.geminiAccountId && fullKeyData.geminiAccountId !== '' ? fullKeyData.geminiAccountId : null
|
||||
claudeAccountId:
|
||||
fullKeyData.claudeAccountId && fullKeyData.claudeAccountId !== ''
|
||||
? fullKeyData.claudeAccountId
|
||||
: null,
|
||||
geminiAccountId:
|
||||
fullKeyData.geminiAccountId && fullKeyData.geminiAccountId !== ''
|
||||
? fullKeyData.geminiAccountId
|
||||
: null
|
||||
},
|
||||
|
||||
|
||||
// 模型和客户端限制信息
|
||||
restrictions: {
|
||||
enableModelRestriction: fullKeyData.enableModelRestriction || false,
|
||||
@@ -344,126 +356,137 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
enableClientRestriction: fullKeyData.enableClientRestriction || false,
|
||||
allowedClients: fullKeyData.allowedClients || []
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
data: responseData
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to process user stats query:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Failed to process user stats query:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: 'Failed to retrieve API key statistics'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📊 用户模型统计查询接口 - 安全的自查询接口
|
||||
router.post('/api/user-model-stats', async (req, res) => {
|
||||
try {
|
||||
const { apiKey, apiId, period = 'monthly' } = req.body;
|
||||
|
||||
let keyData;
|
||||
let keyId;
|
||||
|
||||
const { apiKey, apiId, period = 'monthly' } = req.body
|
||||
|
||||
let keyData
|
||||
let keyId
|
||||
|
||||
if (apiId) {
|
||||
// 通过 apiId 查询
|
||||
if (typeof apiId !== 'string' || !apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)) {
|
||||
if (
|
||||
typeof apiId !== 'string' ||
|
||||
!apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)
|
||||
) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API ID format',
|
||||
message: 'API ID must be a valid UUID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 直接通过 ID 获取 API Key 数据
|
||||
keyData = await redis.getApiKey(apiId);
|
||||
|
||||
keyData = await redis.getApiKey(apiId)
|
||||
|
||||
if (!keyData || Object.keys(keyData).length === 0) {
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`)
|
||||
return res.status(404).json({
|
||||
error: 'API key not found',
|
||||
message: 'The specified API key does not exist'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查是否激活
|
||||
if (keyData.isActive !== 'true') {
|
||||
return res.status(403).json({
|
||||
error: 'API key is disabled',
|
||||
message: 'This API key has been disabled'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyId = apiId;
|
||||
|
||||
|
||||
keyId = apiId
|
||||
|
||||
// 获取使用统计
|
||||
const usage = await redis.getUsageStats(keyId);
|
||||
keyData.usage = { total: usage.total };
|
||||
|
||||
const usage = await redis.getUsageStats(keyId)
|
||||
keyData.usage = { total: usage.total }
|
||||
} else if (apiKey) {
|
||||
// 通过 apiKey 查询(保持向后兼容)
|
||||
// 验证API Key
|
||||
const validation = await apiKeyService.validateApiKey(apiKey);
|
||||
|
||||
const validation = await apiKeyService.validateApiKey(apiKey)
|
||||
|
||||
if (!validation.valid) {
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown';
|
||||
logger.security(`🔒 Invalid API key in user model stats query: ${validation.error} from ${clientIP}`);
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown'
|
||||
logger.security(
|
||||
`🔒 Invalid API key in user model stats query: ${validation.error} from ${clientIP}`
|
||||
)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid API key',
|
||||
message: validation.error
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyData = validation.keyData;
|
||||
keyId = keyData.id;
|
||||
|
||||
const { keyData: validatedKeyData } = validation
|
||||
keyData = validatedKeyData
|
||||
keyId = keyData.id
|
||||
} else {
|
||||
logger.security(`🔒 Missing API key or ID in user model stats query from ${req.ip || 'unknown'}`);
|
||||
logger.security(
|
||||
`🔒 Missing API key or ID in user model stats query from ${req.ip || 'unknown'}`
|
||||
)
|
||||
return res.status(400).json({
|
||||
error: 'API Key or ID is required',
|
||||
message: 'Please provide your API Key or API ID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
logger.api(`📊 User model stats query from key: ${keyData.name} (${keyId}) for period: ${period}`);
|
||||
|
||||
logger.api(
|
||||
`📊 User model stats query from key: ${keyData.name} (${keyId}) for period: ${period}`
|
||||
)
|
||||
|
||||
// 重用管理后台的模型统计逻辑,但只返回该API Key的数据
|
||||
const client = redis.getClientSafe();
|
||||
const client = redis.getClientSafe()
|
||||
// 使用与管理页面相同的时区处理逻辑
|
||||
const tzDate = redis.getDateInTimezone();
|
||||
const today = redis.getDateStringInTimezone();
|
||||
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`;
|
||||
|
||||
const pattern = period === 'daily' ?
|
||||
`usage:${keyId}:model:daily:*:${today}` :
|
||||
`usage:${keyId}:model:monthly:*:${currentMonth}`;
|
||||
|
||||
const keys = await client.keys(pattern);
|
||||
const modelStats = [];
|
||||
|
||||
const tzDate = redis.getDateInTimezone()
|
||||
const today = redis.getDateStringInTimezone()
|
||||
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`
|
||||
|
||||
const pattern =
|
||||
period === 'daily'
|
||||
? `usage:${keyId}:model:daily:*:${today}`
|
||||
: `usage:${keyId}:model:monthly:*:${currentMonth}`
|
||||
|
||||
const keys = await client.keys(pattern)
|
||||
const modelStats = []
|
||||
|
||||
for (const key of keys) {
|
||||
const match = key.match(period === 'daily' ?
|
||||
/usage:.+:model:daily:(.+):\d{4}-\d{2}-\d{2}$/ :
|
||||
/usage:.+:model:monthly:(.+):\d{4}-\d{2}$/
|
||||
);
|
||||
|
||||
if (!match) continue;
|
||||
|
||||
const model = match[1];
|
||||
const data = await client.hgetall(key);
|
||||
|
||||
const match = key.match(
|
||||
period === 'daily'
|
||||
? /usage:.+:model:daily:(.+):\d{4}-\d{2}-\d{2}$/
|
||||
: /usage:.+:model:monthly:(.+):\d{4}-\d{2}$/
|
||||
)
|
||||
|
||||
if (!match) {
|
||||
continue
|
||||
}
|
||||
|
||||
const model = match[1]
|
||||
const data = await client.hgetall(key)
|
||||
|
||||
if (data && Object.keys(data).length > 0) {
|
||||
const usage = {
|
||||
input_tokens: parseInt(data.inputTokens) || 0,
|
||||
output_tokens: parseInt(data.outputTokens) || 0,
|
||||
cache_creation_input_tokens: parseInt(data.cacheCreateTokens) || 0,
|
||||
cache_read_input_tokens: parseInt(data.cacheReadTokens) || 0
|
||||
};
|
||||
|
||||
const costData = CostCalculator.calculateCost(usage, model);
|
||||
|
||||
}
|
||||
|
||||
const costData = CostCalculator.calculateCost(usage, model)
|
||||
|
||||
modelStats.push({
|
||||
model,
|
||||
requests: parseInt(data.requests) || 0,
|
||||
@@ -475,32 +498,31 @@ router.post('/api/user-model-stats', async (req, res) => {
|
||||
costs: costData.costs,
|
||||
formatted: costData.formatted,
|
||||
pricing: costData.pricing
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有详细的模型数据,不显示历史数据以避免混淆
|
||||
// 只有在查询特定时间段时返回空数组,表示该时间段确实没有数据
|
||||
if (modelStats.length === 0) {
|
||||
logger.info(`📊 No model stats found for key ${keyId} in period ${period}`);
|
||||
logger.info(`📊 No model stats found for key ${keyId} in period ${period}`)
|
||||
}
|
||||
|
||||
// 按总token数降序排列
|
||||
modelStats.sort((a, b) => b.allTokens - a.allTokens);
|
||||
modelStats.sort((a, b) => b.allTokens - a.allTokens)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
data: modelStats,
|
||||
period: period
|
||||
});
|
||||
|
||||
period
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to process user model stats query:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Failed to process user model stats query:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: 'Failed to retrieve model statistics'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const logger = require('../utils/logger');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const geminiAccountService = require('../services/geminiAccountService');
|
||||
const { sendGeminiRequest, getAvailableModels } = require('../services/geminiRelayService');
|
||||
const crypto = require('crypto');
|
||||
const sessionHelper = require('../utils/sessionHelper');
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const express = require('express')
|
||||
const router = express.Router()
|
||||
const logger = require('../utils/logger')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const geminiAccountService = require('../services/geminiAccountService')
|
||||
const { sendGeminiRequest, getAvailableModels } = require('../services/geminiRelayService')
|
||||
const crypto = require('crypto')
|
||||
const sessionHelper = require('../utils/sessionHelper')
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
// const { OAuth2Client } = require('google-auth-library'); // OAuth2Client is not used in this file
|
||||
|
||||
// 生成会话哈希
|
||||
@@ -16,24 +16,26 @@ function generateSessionHash(req) {
|
||||
req.headers['user-agent'],
|
||||
req.ip,
|
||||
req.headers['x-api-key']?.substring(0, 10)
|
||||
].filter(Boolean).join(':');
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(':')
|
||||
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex');
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex')
|
||||
}
|
||||
|
||||
// 检查 API Key 权限
|
||||
function checkPermissions(apiKeyData, requiredPermission = 'gemini') {
|
||||
const permissions = apiKeyData.permissions || 'all';
|
||||
return permissions === 'all' || permissions === requiredPermission;
|
||||
const permissions = apiKeyData.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// Gemini 消息处理端点
|
||||
router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
const startTime = Date.now();
|
||||
let abortController = null;
|
||||
const startTime = Date.now()
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
@@ -42,7 +44,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
message: 'This API key does not have permission to access Gemini',
|
||||
type: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 提取请求参数
|
||||
@@ -52,7 +54,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
temperature = 0.7,
|
||||
max_tokens = 4096,
|
||||
stream = false
|
||||
} = req.body;
|
||||
} = req.body
|
||||
|
||||
// 验证必需参数
|
||||
if (!messages || !Array.isArray(messages) || messages.length === 0) {
|
||||
@@ -61,57 +63,58 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
message: 'Messages array is required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 生成会话哈希用于粘性会话
|
||||
const sessionHash = generateSessionHash(req);
|
||||
const sessionHash = generateSessionHash(req)
|
||||
|
||||
// 使用统一调度选择可用的 Gemini 账户(传递请求的模型)
|
||||
let accountId;
|
||||
let accountId
|
||||
try {
|
||||
const schedulerResult = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
sessionHash,
|
||||
model // 传递请求的模型进行过滤
|
||||
);
|
||||
accountId = schedulerResult.accountId;
|
||||
model // 传递请求的模型进行过滤
|
||||
)
|
||||
const { accountId: selectedAccountId } = schedulerResult
|
||||
accountId = selectedAccountId
|
||||
} catch (error) {
|
||||
logger.error('Failed to select Gemini account:', error);
|
||||
logger.error('Failed to select Gemini account:', error)
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
message: error.message || 'No available Gemini accounts',
|
||||
type: 'service_unavailable'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取账户详情
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
if (!account) {
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
message: 'Selected account not found',
|
||||
type: 'service_unavailable'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`);
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`)
|
||||
|
||||
// 标记账户被使用
|
||||
await geminiAccountService.markAccountUsed(account.id);
|
||||
await geminiAccountService.markAccountUsed(account.id)
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开连接
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('Client disconnected, aborting Gemini request');
|
||||
abortController.abort();
|
||||
logger.info('Client disconnected, aborting Gemini request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 发送请求到 Gemini
|
||||
const geminiResponse = await sendGeminiRequest({
|
||||
@@ -126,64 +129,64 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
signal: abortController.signal,
|
||||
projectId: account.projectId,
|
||||
accountId: account.id
|
||||
});
|
||||
})
|
||||
|
||||
if (stream) {
|
||||
// 设置流式响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 流式传输响应
|
||||
for await (const chunk of geminiResponse) {
|
||||
if (abortController.signal.aborted) {
|
||||
break;
|
||||
break
|
||||
}
|
||||
res.write(chunk);
|
||||
res.write(chunk)
|
||||
}
|
||||
|
||||
res.end();
|
||||
res.end()
|
||||
} else {
|
||||
// 非流式响应
|
||||
res.json(geminiResponse);
|
||||
res.json(geminiResponse)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.info(`Gemini request completed in ${duration}ms`);
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`Gemini request completed in ${duration}ms`)
|
||||
} catch (error) {
|
||||
logger.error('Gemini request error:', error);
|
||||
logger.error('Gemini request error:', error)
|
||||
|
||||
// 处理速率限制
|
||||
if (error.status === 429) {
|
||||
if (req.apiKey && req.account) {
|
||||
await geminiAccountService.setAccountRateLimited(req.account.id, true);
|
||||
await geminiAccountService.setAccountRateLimited(req.account.id, true)
|
||||
}
|
||||
}
|
||||
|
||||
// 返回错误响应
|
||||
const status = error.status || 500;
|
||||
const status = error.status || 500
|
||||
const errorResponse = {
|
||||
error: error.error || {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
res.status(status).json(errorResponse);
|
||||
res.status(status).json(errorResponse)
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 获取可用模型列表
|
||||
router.get('/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
@@ -192,16 +195,20 @@ router.get('/models', authenticateApiKey, async (req, res) => {
|
||||
message: 'This API key does not have permission to access Gemini',
|
||||
type: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 选择账户获取模型列表
|
||||
let account = null;
|
||||
let account = null
|
||||
try {
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(apiKeyData, null, null);
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId);
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
null,
|
||||
null
|
||||
)
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error);
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error)
|
||||
}
|
||||
|
||||
if (!account) {
|
||||
@@ -216,32 +223,32 @@ router.get('/models', authenticateApiKey, async (req, res) => {
|
||||
owned_by: 'google'
|
||||
}
|
||||
]
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取模型列表
|
||||
const models = await getAvailableModels(account.accessToken, account.proxy);
|
||||
const models = await getAvailableModels(account.accessToken, account.proxy)
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get Gemini models:', error);
|
||||
logger.error('Failed to get Gemini models:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 使用情况统计(与 Claude 共用)
|
||||
router.get('/usage', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = req.apiKey.usage;
|
||||
const { usage } = req.apiKey
|
||||
|
||||
res.json({
|
||||
object: 'usage',
|
||||
@@ -251,22 +258,22 @@ router.get('/usage', authenticateApiKey, async (req, res) => {
|
||||
daily_requests: usage.daily.requests,
|
||||
monthly_tokens: usage.monthly.tokens,
|
||||
monthly_requests: usage.monthly.requests
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get usage stats:', error);
|
||||
logger.error('Failed to get usage stats:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve usage statistics',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// API Key 信息(与 Claude 共用)
|
||||
router.get('/key-info', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const keyData = req.apiKey;
|
||||
const keyData = req.apiKey
|
||||
|
||||
res.json({
|
||||
id: keyData.id,
|
||||
@@ -274,9 +281,10 @@ router.get('/key-info', authenticateApiKey, async (req, res) => {
|
||||
permissions: keyData.permissions || 'all',
|
||||
token_limit: keyData.tokenLimit,
|
||||
tokens_used: keyData.usage.total.tokens,
|
||||
tokens_remaining: keyData.tokenLimit > 0
|
||||
? Math.max(0, keyData.tokenLimit - keyData.usage.total.tokens)
|
||||
: null,
|
||||
tokens_remaining:
|
||||
keyData.tokenLimit > 0
|
||||
? Math.max(0, keyData.tokenLimit - keyData.usage.total.tokens)
|
||||
: null,
|
||||
rate_limit: {
|
||||
window: keyData.rateLimitWindow,
|
||||
requests: keyData.rateLimitRequests
|
||||
@@ -286,88 +294,105 @@ router.get('/key-info', authenticateApiKey, async (req, res) => {
|
||||
enabled: keyData.enableModelRestriction,
|
||||
models: keyData.restrictedModels
|
||||
}
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get key info:', error);
|
||||
logger.error('Failed to get key info:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve API key information',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 共用的 loadCodeAssist 处理函数
|
||||
async function handleLoadCodeAssist(req, res) {
|
||||
try {
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);
|
||||
logger.info(`accessToken: ${accessToken}`);
|
||||
const requestedModel = req.body.model
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId)
|
||||
logger.info(`accessToken: ${accessToken}`)
|
||||
|
||||
const { metadata, cloudaicompanionProject } = req.body;
|
||||
const { metadata, cloudaicompanionProject } = req.body
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`LoadCodeAssist request (${version})`, {
|
||||
metadata: metadata || {},
|
||||
cloudaicompanionProject: cloudaicompanionProject || null,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const response = await geminiAccountService.loadCodeAssist(client, cloudaicompanionProject);
|
||||
res.json(response);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const response = await geminiAccountService.loadCodeAssist(client, cloudaicompanionProject)
|
||||
res.json(response)
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in loadCodeAssist endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in loadCodeAssist endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 共用的 onboardUser 处理函数
|
||||
async function handleOnboardUser(req, res) {
|
||||
try {
|
||||
const { tierId, cloudaicompanionProject, metadata } = req.body;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const { tierId, cloudaicompanionProject, metadata } = req.body
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);
|
||||
const requestedModel = req.body.model
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId)
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`OnboardUser request (${version})`, {
|
||||
tierId: tierId || 'not provided',
|
||||
cloudaicompanionProject: cloudaicompanionProject || null,
|
||||
metadata: metadata || {},
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
|
||||
// 如果提供了完整参数,直接调用onboardUser
|
||||
if (tierId && metadata) {
|
||||
const response = await geminiAccountService.onboardUser(client, tierId, cloudaicompanionProject, metadata);
|
||||
res.json(response);
|
||||
const response = await geminiAccountService.onboardUser(
|
||||
client,
|
||||
tierId,
|
||||
cloudaicompanionProject,
|
||||
metadata
|
||||
)
|
||||
res.json(response)
|
||||
} else {
|
||||
// 否则执行完整的setupUser流程
|
||||
const response = await geminiAccountService.setupUser(client, cloudaicompanionProject, metadata);
|
||||
res.json(response);
|
||||
const response = await geminiAccountService.setupUser(
|
||||
client,
|
||||
cloudaicompanionProject,
|
||||
metadata
|
||||
)
|
||||
res.json(response)
|
||||
}
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in onboardUser endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in onboardUser endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,9 +400,9 @@ async function handleOnboardUser(req, res) {
|
||||
async function handleCountTokens(req, res) {
|
||||
try {
|
||||
// 处理请求体结构,支持直接 contents 或 request.contents
|
||||
const requestData = req.body.request || req.body;
|
||||
const { contents, model = 'gemini-2.0-flash-exp' } = requestData;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const requestData = req.body.request || req.body
|
||||
const { contents, model = 'gemini-2.0-flash-exp' } = requestData
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 验证必需参数
|
||||
if (!contents || !Array.isArray(contents)) {
|
||||
@@ -386,49 +411,54 @@ async function handleCountTokens(req, res) {
|
||||
message: 'Contents array is required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 使用统一调度选择账号
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model);
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId)
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`CountTokens request (${version})`, {
|
||||
model: model,
|
||||
model,
|
||||
contentsLength: contents.length,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const response = await geminiAccountService.countTokens(client, contents, model);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const response = await geminiAccountService.countTokens(client, contents, model)
|
||||
|
||||
res.json(response);
|
||||
res.json(response)
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in countTokens endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in countTokens endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 共用的 generateContent 处理函数
|
||||
async function handleGenerateContent(req, res) {
|
||||
try {
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 处理不同格式的请求
|
||||
let actualRequestData = requestData;
|
||||
let actualRequestData = requestData
|
||||
if (!requestData) {
|
||||
if (req.body.messages) {
|
||||
// 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象
|
||||
actualRequestData = {
|
||||
contents: req.body.messages.map(msg => ({
|
||||
contents: req.body.messages.map((msg) => ({
|
||||
role: msg.role === 'assistant' ? 'model' : msg.role,
|
||||
parts: [{ text: msg.content }]
|
||||
})),
|
||||
@@ -438,10 +468,10 @@ async function handleGenerateContent(req, res) {
|
||||
topP: req.body.top_p !== undefined ? req.body.top_p : 0.95,
|
||||
topK: req.body.top_k !== undefined ? req.body.top_k : 40
|
||||
}
|
||||
};
|
||||
}
|
||||
} else if (req.body.contents) {
|
||||
// 直接的 Gemini 格式请求(没有 request 包装)
|
||||
actualRequestData = req.body;
|
||||
actualRequestData = req.body
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,35 +482,39 @@ async function handleGenerateContent(req, res) {
|
||||
message: 'Request contents are required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 使用统一调度选择账号
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model);
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
const { accessToken, refreshToken } = account;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
const { accessToken, refreshToken } = account
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`GenerateContent request (${version})`, {
|
||||
model: model,
|
||||
model,
|
||||
userPromptId: user_prompt_id,
|
||||
projectId: project || account.projectId,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const response = await geminiAccountService.generateContent(
|
||||
client,
|
||||
{ model, request: actualRequestData },
|
||||
user_prompt_id,
|
||||
project || account.projectId,
|
||||
req.apiKey?.id // 使用 API Key ID 作为 session ID
|
||||
);
|
||||
)
|
||||
|
||||
// 记录使用统计
|
||||
if (response?.response?.usageMetadata) {
|
||||
try {
|
||||
const usage = response.response.usageMetadata;
|
||||
const usage = response.response.usageMetadata
|
||||
await apiKeyService.recordUsage(
|
||||
req.apiKey.id,
|
||||
usage.promptTokenCount || 0,
|
||||
@@ -489,42 +523,45 @@ async function handleGenerateContent(req, res) {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini usage - Input: ${usage.promptTokenCount}, Output: ${usage.candidatesTokenCount}, Total: ${usage.totalTokenCount}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini usage - Input: ${usage.promptTokenCount}, Output: ${usage.candidatesTokenCount}, Total: ${usage.totalTokenCount}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
res.json(response);
|
||||
res.json(response)
|
||||
} catch (error) {
|
||||
console.log(321, error.response);
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in generateContent endpoint (${version})`, { error: error.message });
|
||||
console.log(321, error.response)
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in generateContent endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 共用的 streamGenerateContent 处理函数
|
||||
async function handleStreamGenerateContent(req, res) {
|
||||
let abortController = null;
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 处理不同格式的请求
|
||||
let actualRequestData = requestData;
|
||||
let actualRequestData = requestData
|
||||
if (!requestData) {
|
||||
if (req.body.messages) {
|
||||
// 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象
|
||||
actualRequestData = {
|
||||
contents: req.body.messages.map(msg => ({
|
||||
contents: req.body.messages.map((msg) => ({
|
||||
role: msg.role === 'assistant' ? 'model' : msg.role,
|
||||
parts: [{ text: msg.content }]
|
||||
})),
|
||||
@@ -534,10 +571,10 @@ async function handleStreamGenerateContent(req, res) {
|
||||
topP: req.body.top_p !== undefined ? req.body.top_p : 0.95,
|
||||
topK: req.body.top_k !== undefined ? req.body.top_k : 40
|
||||
}
|
||||
};
|
||||
}
|
||||
} else if (req.body.contents) {
|
||||
// 直接的 Gemini 格式请求(没有 request 包装)
|
||||
actualRequestData = req.body;
|
||||
actualRequestData = req.body
|
||||
}
|
||||
}
|
||||
|
||||
@@ -548,34 +585,38 @@ async function handleStreamGenerateContent(req, res) {
|
||||
message: 'Request contents are required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 使用统一调度选择账号
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model);
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
const { accessToken, refreshToken } = account;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
const { accessToken, refreshToken } = account
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`StreamGenerateContent request (${version})`, {
|
||||
model: model,
|
||||
model,
|
||||
userPromptId: user_prompt_id,
|
||||
projectId: project || account.projectId,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开连接
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('Client disconnected, aborting stream request');
|
||||
abortController.abort();
|
||||
logger.info('Client disconnected, aborting stream request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const streamResponse = await geminiAccountService.generateContentStream(
|
||||
client,
|
||||
{ model, request: actualRequestData },
|
||||
@@ -583,48 +624,48 @@ async function handleStreamGenerateContent(req, res) {
|
||||
project || account.projectId,
|
||||
req.apiKey?.id, // 使用 API Key ID 作为 session ID
|
||||
abortController.signal // 传递中止信号
|
||||
);
|
||||
)
|
||||
|
||||
// 设置 SSE 响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 处理流式响应并捕获usage数据
|
||||
let buffer = '';
|
||||
let buffer = ''
|
||||
let totalUsage = {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
let usageReported = false;
|
||||
}
|
||||
const usageReported = false
|
||||
|
||||
streamResponse.on('data', (chunk) => {
|
||||
try {
|
||||
const chunkStr = chunk.toString();
|
||||
|
||||
const chunkStr = chunk.toString()
|
||||
|
||||
// 直接转发数据到客户端
|
||||
if (!res.destroyed) {
|
||||
res.write(chunkStr);
|
||||
res.write(chunkStr)
|
||||
}
|
||||
|
||||
// 同时解析数据以捕获usage信息
|
||||
buffer += chunkStr;
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || '';
|
||||
buffer += chunkStr
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ') && line.length > 6) {
|
||||
try {
|
||||
const jsonStr = line.slice(6);
|
||||
const jsonStr = line.slice(6)
|
||||
if (jsonStr && jsonStr !== '[DONE]') {
|
||||
const data = JSON.parse(jsonStr);
|
||||
|
||||
const data = JSON.parse(jsonStr)
|
||||
|
||||
// 从响应中提取usage数据
|
||||
if (data.response?.usageMetadata) {
|
||||
totalUsage = data.response.usageMetadata;
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage);
|
||||
totalUsage = data.response.usageMetadata
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
@@ -633,13 +674,13 @@ async function handleStreamGenerateContent(req, res) {
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error processing stream chunk:', error);
|
||||
logger.error('Error processing stream chunk:', error)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
streamResponse.on('end', async () => {
|
||||
logger.info('Stream completed successfully');
|
||||
|
||||
logger.info('Stream completed successfully')
|
||||
|
||||
// 记录使用统计
|
||||
if (!usageReported && totalUsage.totalTokenCount > 0) {
|
||||
try {
|
||||
@@ -651,33 +692,34 @@ async function handleStreamGenerateContent(req, res) {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
res.end();
|
||||
});
|
||||
|
||||
res.end()
|
||||
})
|
||||
|
||||
streamResponse.on('error', (error) => {
|
||||
logger.error('Stream error:', error);
|
||||
logger.error('Stream error:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Stream error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
} else {
|
||||
res.end();
|
||||
res.end()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in streamGenerateContent endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in streamGenerateContent endpoint (${version})`, { error: error.message })
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
@@ -685,29 +727,38 @@ async function handleStreamGenerateContent(req, res) {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 注册所有路由端点
|
||||
// v1internal 版本的端点
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist);
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, handleOnboardUser);
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, handleCountTokens);
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, handleGenerateContent);
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent);
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist)
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, handleOnboardUser)
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, handleCountTokens)
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, handleGenerateContent)
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent)
|
||||
|
||||
// v1beta 版本的端点 - 支持动态模型名称
|
||||
router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist);
|
||||
router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, handleOnboardUser);
|
||||
router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, handleCountTokens);
|
||||
router.post('/v1beta/models/:modelName\\:generateContent', authenticateApiKey, handleGenerateContent);
|
||||
router.post('/v1beta/models/:modelName\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent);
|
||||
router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist)
|
||||
router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, handleOnboardUser)
|
||||
router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, handleCountTokens)
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:generateContent',
|
||||
authenticateApiKey,
|
||||
handleGenerateContent
|
||||
)
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:streamGenerateContent',
|
||||
authenticateApiKey,
|
||||
handleStreamGenerateContent
|
||||
)
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -3,41 +3,41 @@
|
||||
* 提供 OpenAI 格式的 API 接口,内部转发到 Claude
|
||||
*/
|
||||
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const logger = require('../utils/logger');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const claudeRelayService = require('../services/claudeRelayService');
|
||||
const openaiToClaude = require('../services/openaiToClaude');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler');
|
||||
const claudeCodeHeadersService = require('../services/claudeCodeHeadersService');
|
||||
const sessionHelper = require('../utils/sessionHelper');
|
||||
const express = require('express')
|
||||
const router = express.Router()
|
||||
const fs = require('fs')
|
||||
const path = require('path')
|
||||
const logger = require('../utils/logger')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const claudeRelayService = require('../services/claudeRelayService')
|
||||
const openaiToClaude = require('../services/openaiToClaude')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler')
|
||||
const claudeCodeHeadersService = require('../services/claudeCodeHeadersService')
|
||||
const sessionHelper = require('../utils/sessionHelper')
|
||||
|
||||
// 加载模型定价数据
|
||||
let modelPricingData = {};
|
||||
let modelPricingData = {}
|
||||
try {
|
||||
const pricingPath = path.join(__dirname, '../../data/model_pricing.json');
|
||||
const pricingContent = fs.readFileSync(pricingPath, 'utf8');
|
||||
modelPricingData = JSON.parse(pricingContent);
|
||||
logger.info('✅ Model pricing data loaded successfully');
|
||||
const pricingPath = path.join(__dirname, '../../data/model_pricing.json')
|
||||
const pricingContent = fs.readFileSync(pricingPath, 'utf8')
|
||||
modelPricingData = JSON.parse(pricingContent)
|
||||
logger.info('✅ Model pricing data loaded successfully')
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to load model pricing data:', error);
|
||||
logger.error('❌ Failed to load model pricing data:', error)
|
||||
}
|
||||
|
||||
// 🔧 辅助函数:检查 API Key 权限
|
||||
function checkPermissions(apiKeyData, requiredPermission = 'claude') {
|
||||
const permissions = apiKeyData.permissions || 'all';
|
||||
return permissions === 'all' || permissions === requiredPermission;
|
||||
const permissions = apiKeyData.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// 📋 OpenAI 兼容的模型列表端点
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'claude')) {
|
||||
return res.status(403).json({
|
||||
@@ -46,9 +46,9 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// Claude 模型列表 - 只返回 opus-4 和 sonnet-4
|
||||
let models = [
|
||||
{
|
||||
@@ -63,36 +63,36 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
created: 1736726400, // 2025-01-13
|
||||
owned_by: 'anthropic'
|
||||
}
|
||||
];
|
||||
|
||||
]
|
||||
|
||||
// 如果启用了模型限制,过滤模型列表
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels?.length > 0) {
|
||||
models = models.filter(model => apiKeyData.restrictedModels.includes(model.id));
|
||||
models = models.filter((model) => apiKeyData.restrictedModels.includes(model.id))
|
||||
}
|
||||
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get OpenAI-Claude models:', error);
|
||||
logger.error('❌ Failed to get OpenAI-Claude models:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 📄 OpenAI 兼容的模型详情端点
|
||||
router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const modelId = req.params.model;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
const modelId = req.params.model
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'claude')) {
|
||||
return res.status(403).json({
|
||||
@@ -101,9 +101,9 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels?.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(modelId)) {
|
||||
@@ -113,16 +113,16 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_found'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 从 model_pricing.json 获取模型信息
|
||||
const modelData = modelPricingData[modelId];
|
||||
|
||||
const modelData = modelPricingData[modelId]
|
||||
|
||||
// 构建标准 OpenAI 格式的模型响应
|
||||
let modelInfo;
|
||||
|
||||
let modelInfo
|
||||
|
||||
if (modelData) {
|
||||
// 如果在 pricing 文件中找到了模型
|
||||
modelInfo = {
|
||||
@@ -133,7 +133,7 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
permission: [],
|
||||
root: modelId,
|
||||
parent: null
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// 如果没找到,返回默认信息(但仍保持正确格式)
|
||||
modelInfo = {
|
||||
@@ -144,28 +144,28 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
permission: [],
|
||||
root: modelId,
|
||||
parent: null
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
res.json(modelInfo);
|
||||
|
||||
|
||||
res.json(modelInfo)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get model details:', error);
|
||||
logger.error('❌ Failed to get model details:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve model details',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 🔧 处理聊天完成请求的核心函数
|
||||
async function handleChatCompletion(req, res, apiKeyData) {
|
||||
const startTime = Date.now();
|
||||
let abortController = null;
|
||||
|
||||
const startTime = Date.now()
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'claude')) {
|
||||
@@ -175,20 +175,20 @@ async function handleChatCompletion(req, res, apiKeyData) {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 记录原始请求
|
||||
logger.debug('📥 Received OpenAI format request:', {
|
||||
model: req.body.model,
|
||||
messageCount: req.body.messages?.length,
|
||||
stream: req.body.stream,
|
||||
maxTokens: req.body.max_tokens
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 转换 OpenAI 请求为 Claude 格式
|
||||
const claudeRequest = openaiToClaude.convertRequest(req.body);
|
||||
|
||||
const claudeRequest = openaiToClaude.convertRequest(req.body)
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels?.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(claudeRequest.model)) {
|
||||
@@ -198,114 +198,119 @@ async function handleChatCompletion(req, res, apiKeyData) {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_allowed'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(claudeRequest);
|
||||
|
||||
const sessionHash = sessionHelper.generateSessionHash(claudeRequest)
|
||||
|
||||
// 选择可用的Claude账户
|
||||
const accountSelection = await unifiedClaudeScheduler.selectAccountForApiKey(apiKeyData, sessionHash, claudeRequest.model);
|
||||
const accountId = accountSelection.accountId;
|
||||
|
||||
const accountSelection = await unifiedClaudeScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
sessionHash,
|
||||
claudeRequest.model
|
||||
)
|
||||
const { accountId } = accountSelection
|
||||
|
||||
// 获取该账号存储的 Claude Code headers
|
||||
const claudeCodeHeaders = await claudeCodeHeadersService.getAccountHeaders(accountId);
|
||||
|
||||
const claudeCodeHeaders = await claudeCodeHeadersService.getAccountHeaders(accountId)
|
||||
|
||||
logger.debug(`📋 Using Claude Code headers for account ${accountId}:`, {
|
||||
userAgent: claudeCodeHeaders['user-agent']
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 处理流式请求
|
||||
if (claudeRequest.stream) {
|
||||
logger.info(`🌊 Processing OpenAI stream request for model: ${req.body.model}`);
|
||||
|
||||
logger.info(`🌊 Processing OpenAI stream request for model: ${req.body.model}`)
|
||||
|
||||
// 设置 SSE 响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
|
||||
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('🔌 Client disconnected, aborting Claude request');
|
||||
abortController.abort();
|
||||
logger.info('🔌 Client disconnected, aborting Claude request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 使用转换后的响应流 (使用 OAuth-only beta header,添加 Claude Code 必需的 headers)
|
||||
await claudeRelayService.relayStreamRequestWithUsageCapture(
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
res,
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
res,
|
||||
claudeCodeHeaders,
|
||||
(usage) => {
|
||||
// 记录使用统计
|
||||
if (usage && usage.input_tokens !== undefined && usage.output_tokens !== undefined) {
|
||||
const inputTokens = usage.input_tokens || 0;
|
||||
const outputTokens = usage.output_tokens || 0;
|
||||
const cacheCreateTokens = usage.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0;
|
||||
const model = usage.model || claudeRequest.model;
|
||||
|
||||
apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
accountId
|
||||
).catch(error => {
|
||||
logger.error('❌ Failed to record usage:', error);
|
||||
});
|
||||
const inputTokens = usage.input_tokens || 0
|
||||
const outputTokens = usage.output_tokens || 0
|
||||
const cacheCreateTokens = usage.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0
|
||||
const model = usage.model || claudeRequest.model
|
||||
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
apiKeyData.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
accountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record usage:', error)
|
||||
})
|
||||
}
|
||||
},
|
||||
// 流转换器
|
||||
(() => {
|
||||
// 为每个请求创建独立的会话ID
|
||||
const sessionId = `chatcmpl-${Math.random().toString(36).substring(2, 15)}${Math.random().toString(36).substring(2, 15)}`;
|
||||
return (chunk) => {
|
||||
return openaiToClaude.convertStreamChunk(chunk, req.body.model, sessionId);
|
||||
};
|
||||
const sessionId = `chatcmpl-${Math.random().toString(36).substring(2, 15)}${Math.random().toString(36).substring(2, 15)}`
|
||||
return (chunk) => openaiToClaude.convertStreamChunk(chunk, req.body.model, sessionId)
|
||||
})(),
|
||||
{ betaHeader: 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14' }
|
||||
);
|
||||
|
||||
{
|
||||
betaHeader:
|
||||
'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14'
|
||||
}
|
||||
)
|
||||
} else {
|
||||
// 非流式请求
|
||||
logger.info(`📄 Processing OpenAI non-stream request for model: ${req.body.model}`);
|
||||
|
||||
logger.info(`📄 Processing OpenAI non-stream request for model: ${req.body.model}`)
|
||||
|
||||
// 发送请求到 Claude (使用 OAuth-only beta header,添加 Claude Code 必需的 headers)
|
||||
const claudeResponse = await claudeRelayService.relayRequest(
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
req,
|
||||
res,
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
req,
|
||||
res,
|
||||
claudeCodeHeaders,
|
||||
{ betaHeader: 'oauth-2025-04-20' }
|
||||
);
|
||||
|
||||
)
|
||||
|
||||
// 解析 Claude 响应
|
||||
let claudeData;
|
||||
let claudeData
|
||||
try {
|
||||
claudeData = JSON.parse(claudeResponse.body);
|
||||
claudeData = JSON.parse(claudeResponse.body)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to parse Claude response:', error);
|
||||
logger.error('❌ Failed to parse Claude response:', error)
|
||||
return res.status(502).json({
|
||||
error: {
|
||||
message: 'Invalid response from Claude API',
|
||||
type: 'api_error',
|
||||
code: 'invalid_response'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 处理错误响应
|
||||
if (claudeResponse.statusCode >= 400) {
|
||||
return res.status(claudeResponse.statusCode).json({
|
||||
@@ -314,64 +319,66 @@ async function handleChatCompletion(req, res, apiKeyData) {
|
||||
type: claudeData.error?.type || 'api_error',
|
||||
code: claudeData.error?.code || 'unknown_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 转换为 OpenAI 格式
|
||||
const openaiResponse = openaiToClaude.convertResponse(claudeData, req.body.model);
|
||||
|
||||
const openaiResponse = openaiToClaude.convertResponse(claudeData, req.body.model)
|
||||
|
||||
// 记录使用统计
|
||||
if (claudeData.usage) {
|
||||
const usage = claudeData.usage;
|
||||
apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
usage.input_tokens || 0,
|
||||
usage.output_tokens || 0,
|
||||
usage.cache_creation_input_tokens || 0,
|
||||
usage.cache_read_input_tokens || 0,
|
||||
claudeRequest.model,
|
||||
accountId
|
||||
).catch(error => {
|
||||
logger.error('❌ Failed to record usage:', error);
|
||||
});
|
||||
const { usage } = claudeData
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
apiKeyData.id,
|
||||
usage.input_tokens || 0,
|
||||
usage.output_tokens || 0,
|
||||
usage.cache_creation_input_tokens || 0,
|
||||
usage.cache_read_input_tokens || 0,
|
||||
claudeRequest.model,
|
||||
accountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record usage:', error)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 返回 OpenAI 格式响应
|
||||
res.json(openaiResponse);
|
||||
res.json(openaiResponse)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.info(`✅ OpenAI-Claude request completed in ${duration}ms`);
|
||||
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`✅ OpenAI-Claude request completed in ${duration}ms`)
|
||||
} catch (error) {
|
||||
logger.error('❌ OpenAI-Claude request error:', error);
|
||||
|
||||
const status = error.status || 500;
|
||||
logger.error('❌ OpenAI-Claude request error:', error)
|
||||
|
||||
const status = error.status || 500
|
||||
res.status(status).json({
|
||||
error: {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 🚀 OpenAI 兼容的聊天完成端点
|
||||
router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
await handleChatCompletion(req, res, req.apiKey);
|
||||
});
|
||||
await handleChatCompletion(req, res, req.apiKey)
|
||||
})
|
||||
|
||||
// 🔧 OpenAI 兼容的 completions 端点(传统格式,转换为 chat 格式)
|
||||
router.post('/v1/completions', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 验证必需参数
|
||||
if (!req.body.prompt) {
|
||||
return res.status(400).json({
|
||||
@@ -380,11 +387,11 @@ router.post('/v1/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'invalid_request'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 将传统 completions 格式转换为 chat 格式
|
||||
const originalBody = req.body;
|
||||
const originalBody = req.body
|
||||
req.body = {
|
||||
model: originalBody.model,
|
||||
messages: [
|
||||
@@ -403,21 +410,21 @@ router.post('/v1/completions', authenticateApiKey, async (req, res) => {
|
||||
frequency_penalty: originalBody.frequency_penalty,
|
||||
logit_bias: originalBody.logit_bias,
|
||||
user: originalBody.user
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 使用共享的处理函数
|
||||
await handleChatCompletion(req, res, apiKeyData);
|
||||
|
||||
await handleChatCompletion(req, res, apiKeyData)
|
||||
} catch (error) {
|
||||
logger.error('❌ OpenAI completions error:', error);
|
||||
logger.error('❌ OpenAI completions error:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to process completion request',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const logger = require('../utils/logger');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const geminiAccountService = require('../services/geminiAccountService');
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler');
|
||||
const { getAvailableModels } = require('../services/geminiRelayService');
|
||||
const crypto = require('crypto');
|
||||
const express = require('express')
|
||||
const router = express.Router()
|
||||
const logger = require('../utils/logger')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const geminiAccountService = require('../services/geminiAccountService')
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler')
|
||||
const { getAvailableModels } = require('../services/geminiRelayService')
|
||||
const crypto = require('crypto')
|
||||
|
||||
// 生成会话哈希
|
||||
function generateSessionHash(req) {
|
||||
@@ -13,167 +13,182 @@ function generateSessionHash(req) {
|
||||
req.headers['user-agent'],
|
||||
req.ip,
|
||||
req.headers['authorization']?.substring(0, 20)
|
||||
].filter(Boolean).join(':');
|
||||
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex');
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(':')
|
||||
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex')
|
||||
}
|
||||
|
||||
// 检查 API Key 权限
|
||||
function checkPermissions(apiKeyData, requiredPermission = 'gemini') {
|
||||
const permissions = apiKeyData.permissions || 'all';
|
||||
return permissions === 'all' || permissions === requiredPermission;
|
||||
const permissions = apiKeyData.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// 转换 OpenAI 消息格式到 Gemini 格式
|
||||
function convertMessagesToGemini(messages) {
|
||||
const contents = [];
|
||||
let systemInstruction = '';
|
||||
|
||||
const contents = []
|
||||
let systemInstruction = ''
|
||||
|
||||
// 辅助函数:提取文本内容
|
||||
function extractTextContent(content) {
|
||||
// 处理 null 或 undefined
|
||||
if (content == null) {
|
||||
return '';
|
||||
if (content === null || content === undefined) {
|
||||
return ''
|
||||
}
|
||||
|
||||
|
||||
// 处理字符串
|
||||
if (typeof content === 'string') {
|
||||
return content;
|
||||
return content
|
||||
}
|
||||
|
||||
|
||||
// 处理数组格式的内容
|
||||
if (Array.isArray(content)) {
|
||||
return content.map(item => {
|
||||
if (item == null) return '';
|
||||
if (typeof item === 'string') {
|
||||
return item;
|
||||
}
|
||||
if (typeof item === 'object') {
|
||||
// 处理 {type: 'text', text: '...'} 格式
|
||||
if (item.type === 'text' && item.text) {
|
||||
return item.text;
|
||||
return content
|
||||
.map((item) => {
|
||||
if (item === null || item === undefined) {
|
||||
return ''
|
||||
}
|
||||
// 处理 {text: '...'} 格式
|
||||
if (item.text) {
|
||||
return item.text;
|
||||
if (typeof item === 'string') {
|
||||
return item
|
||||
}
|
||||
// 处理嵌套的对象或数组
|
||||
if (item.content) {
|
||||
return extractTextContent(item.content);
|
||||
if (typeof item === 'object') {
|
||||
// 处理 {type: 'text', text: '...'} 格式
|
||||
if (item.type === 'text' && item.text) {
|
||||
return item.text
|
||||
}
|
||||
// 处理 {text: '...'} 格式
|
||||
if (item.text) {
|
||||
return item.text
|
||||
}
|
||||
// 处理嵌套的对象或数组
|
||||
if (item.content) {
|
||||
return extractTextContent(item.content)
|
||||
}
|
||||
}
|
||||
}
|
||||
return '';
|
||||
}).join('');
|
||||
return ''
|
||||
})
|
||||
.join('')
|
||||
}
|
||||
|
||||
|
||||
// 处理对象格式的内容
|
||||
if (typeof content === 'object') {
|
||||
// 处理 {text: '...'} 格式
|
||||
if (content.text) {
|
||||
return content.text;
|
||||
return content.text
|
||||
}
|
||||
// 处理 {content: '...'} 格式
|
||||
if (content.content) {
|
||||
return extractTextContent(content.content);
|
||||
return extractTextContent(content.content)
|
||||
}
|
||||
// 处理 {parts: [{text: '...'}]} 格式
|
||||
if (content.parts && Array.isArray(content.parts)) {
|
||||
return content.parts.map(part => {
|
||||
if (part && part.text) {
|
||||
return part.text;
|
||||
}
|
||||
return '';
|
||||
}).join('');
|
||||
return content.parts
|
||||
.map((part) => {
|
||||
if (part && part.text) {
|
||||
return part.text
|
||||
}
|
||||
return ''
|
||||
})
|
||||
.join('')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 最后的后备选项:只有在内容确实不为空且有意义时才转换为字符串
|
||||
if (content !== undefined && content !== null && content !== '' && typeof content !== 'object') {
|
||||
return String(content);
|
||||
if (
|
||||
content !== undefined &&
|
||||
content !== null &&
|
||||
content !== '' &&
|
||||
typeof content !== 'object'
|
||||
) {
|
||||
return String(content)
|
||||
}
|
||||
|
||||
return '';
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
|
||||
for (const message of messages) {
|
||||
const textContent = extractTextContent(message.content);
|
||||
|
||||
const textContent = extractTextContent(message.content)
|
||||
|
||||
if (message.role === 'system') {
|
||||
systemInstruction += (systemInstruction ? '\n\n' : '') + textContent;
|
||||
systemInstruction += (systemInstruction ? '\n\n' : '') + textContent
|
||||
} else if (message.role === 'user') {
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [{ text: textContent }]
|
||||
});
|
||||
})
|
||||
} else if (message.role === 'assistant') {
|
||||
contents.push({
|
||||
role: 'model',
|
||||
parts: [{ text: textContent }]
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return { contents, systemInstruction };
|
||||
|
||||
return { contents, systemInstruction }
|
||||
}
|
||||
|
||||
// 转换 Gemini 响应到 OpenAI 格式
|
||||
function convertGeminiResponseToOpenAI(geminiResponse, model, stream = false) {
|
||||
if (stream) {
|
||||
// 处理流式响应 - 原样返回 SSE 数据
|
||||
return geminiResponse;
|
||||
return geminiResponse
|
||||
} else {
|
||||
// 非流式响应转换
|
||||
// 处理嵌套的 response 结构
|
||||
const actualResponse = geminiResponse.response || geminiResponse;
|
||||
|
||||
const actualResponse = geminiResponse.response || geminiResponse
|
||||
|
||||
if (actualResponse.candidates && actualResponse.candidates.length > 0) {
|
||||
const candidate = actualResponse.candidates[0];
|
||||
const content = candidate.content?.parts?.[0]?.text || '';
|
||||
const finishReason = candidate.finishReason?.toLowerCase() || 'stop';
|
||||
const candidate = actualResponse.candidates[0]
|
||||
const content = candidate.content?.parts?.[0]?.text || ''
|
||||
const finishReason = candidate.finishReason?.toLowerCase() || 'stop'
|
||||
|
||||
// 计算 token 使用量
|
||||
const usage = actualResponse.usageMetadata || {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: content
|
||||
},
|
||||
finish_reason: finishReason
|
||||
}],
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content
|
||||
},
|
||||
finish_reason: finishReason
|
||||
}
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: usage.promptTokenCount,
|
||||
completion_tokens: usage.candidatesTokenCount,
|
||||
total_tokens: usage.totalTokenCount
|
||||
}
|
||||
};
|
||||
}
|
||||
} else {
|
||||
throw new Error('No response from Gemini');
|
||||
throw new Error('No response from Gemini')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI 兼容的聊天完成端点
|
||||
router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
const startTime = Date.now();
|
||||
let abortController = null;
|
||||
let account = null; // Declare account outside try block for error handling
|
||||
let accountSelection = null; // Declare accountSelection for error handling
|
||||
let sessionHash = null; // Declare sessionHash for error handling
|
||||
|
||||
const startTime = Date.now()
|
||||
let abortController = null
|
||||
let account = null // Declare account outside try block for error handling
|
||||
let accountSelection = null // Declare accountSelection for error handling
|
||||
let sessionHash = null // Declare sessionHash for error handling
|
||||
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
return res.status(403).json({
|
||||
@@ -182,25 +197,25 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
// 处理请求体结构 - 支持多种格式
|
||||
let requestBody = req.body;
|
||||
|
||||
let requestBody = req.body
|
||||
|
||||
// 如果请求体被包装在 body 字段中,解包它
|
||||
if (req.body.body && typeof req.body.body === 'object') {
|
||||
requestBody = req.body.body;
|
||||
requestBody = req.body.body
|
||||
}
|
||||
|
||||
|
||||
// 从 URL 路径中提取模型信息(如果存在)
|
||||
let urlModel = null;
|
||||
const urlPath = req.body?.config?.url || req.originalUrl || req.url;
|
||||
const modelMatch = urlPath.match(/\/([^/]+):(?:stream)?[Gg]enerateContent/);
|
||||
let urlModel = null
|
||||
const urlPath = req.body?.config?.url || req.originalUrl || req.url
|
||||
const modelMatch = urlPath.match(/\/([^/]+):(?:stream)?[Gg]enerateContent/)
|
||||
if (modelMatch) {
|
||||
urlModel = modelMatch[1];
|
||||
logger.debug(`Extracted model from URL: ${urlModel}`);
|
||||
urlModel = modelMatch[1]
|
||||
logger.debug(`Extracted model from URL: ${urlModel}`)
|
||||
}
|
||||
|
||||
|
||||
// 提取请求参数
|
||||
const {
|
||||
messages: requestMessages,
|
||||
@@ -209,19 +224,19 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
temperature = 0.7,
|
||||
max_tokens = 4096,
|
||||
stream = false
|
||||
} = requestBody;
|
||||
|
||||
} = requestBody
|
||||
|
||||
// 检查URL中是否包含stream标识
|
||||
const isStreamFromUrl = urlPath && urlPath.includes('streamGenerateContent');
|
||||
const actualStream = stream || isStreamFromUrl;
|
||||
const isStreamFromUrl = urlPath && urlPath.includes('streamGenerateContent')
|
||||
const actualStream = stream || isStreamFromUrl
|
||||
|
||||
// 优先使用 URL 中的模型,其次是请求体中的模型
|
||||
const model = urlModel || bodyModel;
|
||||
const model = urlModel || bodyModel
|
||||
|
||||
// 支持两种格式: OpenAI 的 messages 或 Gemini 的 contents
|
||||
let messages = requestMessages;
|
||||
let messages = requestMessages
|
||||
if (requestContents && Array.isArray(requestContents)) {
|
||||
messages = requestContents;
|
||||
messages = requestContents
|
||||
}
|
||||
|
||||
// 验证必需参数
|
||||
@@ -232,9 +247,9 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'invalid_request'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(model)) {
|
||||
@@ -244,13 +259,13 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_allowed'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 转换消息格式
|
||||
const { contents: geminiContents, systemInstruction } = convertMessagesToGemini(messages);
|
||||
|
||||
const { contents: geminiContents, systemInstruction } = convertMessagesToGemini(messages)
|
||||
|
||||
// 构建 Gemini 请求体
|
||||
const geminiRequestBody = {
|
||||
contents: geminiContents,
|
||||
@@ -259,24 +274,28 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
maxOutputTokens: max_tokens,
|
||||
candidateCount: 1
|
||||
}
|
||||
};
|
||||
|
||||
if (systemInstruction) {
|
||||
geminiRequestBody.systemInstruction = { parts: [{ text: systemInstruction }] };
|
||||
}
|
||||
|
||||
|
||||
if (systemInstruction) {
|
||||
geminiRequestBody.systemInstruction = { parts: [{ text: systemInstruction }] }
|
||||
}
|
||||
|
||||
// 生成会话哈希用于粘性会话
|
||||
sessionHash = generateSessionHash(req);
|
||||
|
||||
sessionHash = generateSessionHash(req)
|
||||
|
||||
// 选择可用的 Gemini 账户
|
||||
try {
|
||||
accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(apiKeyData, sessionHash, model);
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId);
|
||||
accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId)
|
||||
} catch (error) {
|
||||
logger.error('Failed to select Gemini account:', error);
|
||||
account = null;
|
||||
logger.error('Failed to select Gemini account:', error)
|
||||
account = null
|
||||
}
|
||||
|
||||
|
||||
if (!account) {
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
@@ -284,35 +303,38 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'service_unavailable',
|
||||
code: 'service_unavailable'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`);
|
||||
|
||||
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`)
|
||||
|
||||
// 标记账户被使用
|
||||
await geminiAccountService.markAccountUsed(account.id);
|
||||
|
||||
await geminiAccountService.markAccountUsed(account.id)
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开连接
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('Client disconnected, aborting Gemini request');
|
||||
abortController.abort();
|
||||
logger.info('Client disconnected, aborting Gemini request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 获取OAuth客户端
|
||||
const client = await geminiAccountService.getOauthClient(account.accessToken, account.refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(
|
||||
account.accessToken,
|
||||
account.refreshToken
|
||||
)
|
||||
if (actualStream) {
|
||||
// 流式响应
|
||||
logger.info('StreamGenerateContent request', {
|
||||
model: model,
|
||||
model,
|
||||
projectId: account.projectId,
|
||||
apiKeyId: apiKeyData.id
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
const streamResponse = await geminiAccountService.generateContentStream(
|
||||
client,
|
||||
{ model, request: geminiRequestBody },
|
||||
@@ -320,93 +342,101 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
account.projectId, // 使用有权限的项目ID
|
||||
apiKeyData.id, // 使用 API Key ID 作为 session ID
|
||||
abortController.signal // 传递中止信号
|
||||
);
|
||||
|
||||
)
|
||||
|
||||
// 设置流式响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 处理流式响应,转换为 OpenAI 格式
|
||||
let buffer = '';
|
||||
|
||||
let buffer = ''
|
||||
|
||||
// 发送初始的空消息,符合 OpenAI 流式格式
|
||||
const initialChunk = {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion.chunk',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: { role: 'assistant' },
|
||||
finish_reason: null
|
||||
}]
|
||||
};
|
||||
res.write(`data: ${JSON.stringify(initialChunk)}\n\n`);
|
||||
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: { role: 'assistant' },
|
||||
finish_reason: null
|
||||
}
|
||||
]
|
||||
}
|
||||
res.write(`data: ${JSON.stringify(initialChunk)}\n\n`)
|
||||
|
||||
// 用于收集usage数据
|
||||
let totalUsage = {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
let usageReported = false;
|
||||
}
|
||||
const usageReported = false
|
||||
|
||||
streamResponse.on('data', (chunk) => {
|
||||
try {
|
||||
const chunkStr = chunk.toString();
|
||||
|
||||
const chunkStr = chunk.toString()
|
||||
|
||||
if (!chunkStr.trim()) {
|
||||
return;
|
||||
return
|
||||
}
|
||||
|
||||
buffer += chunkStr;
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || ''; // 保留最后一个不完整的行
|
||||
|
||||
|
||||
buffer += chunkStr
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || '' // 保留最后一个不完整的行
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.trim()) continue;
|
||||
|
||||
// 处理 SSE 格式
|
||||
let jsonData = line;
|
||||
if (line.startsWith('data: ')) {
|
||||
jsonData = line.substring(6).trim();
|
||||
if (!line.trim()) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (!jsonData || jsonData === '[DONE]') continue;
|
||||
|
||||
|
||||
// 处理 SSE 格式
|
||||
let jsonData = line
|
||||
if (line.startsWith('data: ')) {
|
||||
jsonData = line.substring(6).trim()
|
||||
}
|
||||
|
||||
if (!jsonData || jsonData === '[DONE]') {
|
||||
continue
|
||||
}
|
||||
|
||||
try {
|
||||
const data = JSON.parse(jsonData);
|
||||
|
||||
const data = JSON.parse(jsonData)
|
||||
|
||||
// 捕获usage数据
|
||||
if (data.response?.usageMetadata) {
|
||||
totalUsage = data.response.usageMetadata;
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage);
|
||||
totalUsage = data.response.usageMetadata
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage)
|
||||
}
|
||||
|
||||
|
||||
// 转换为 OpenAI 流式格式
|
||||
if (data.response?.candidates && data.response.candidates.length > 0) {
|
||||
const candidate = data.response.candidates[0];
|
||||
const content = candidate.content?.parts?.[0]?.text || '';
|
||||
const finishReason = candidate.finishReason;
|
||||
|
||||
const candidate = data.response.candidates[0]
|
||||
const content = candidate.content?.parts?.[0]?.text || ''
|
||||
const { finishReason } = candidate
|
||||
|
||||
// 只有当有内容或者是结束标记时才发送数据
|
||||
if (content || finishReason === 'STOP') {
|
||||
const openaiChunk = {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion.chunk',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: content ? { content: content } : {},
|
||||
finish_reason: finishReason === 'STOP' ? 'stop' : null
|
||||
}]
|
||||
};
|
||||
|
||||
res.write(`data: ${JSON.stringify(openaiChunk)}\n\n`);
|
||||
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: content ? { content } : {},
|
||||
finish_reason: finishReason === 'STOP' ? 'stop' : null
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
res.write(`data: ${JSON.stringify(openaiChunk)}\n\n`)
|
||||
|
||||
// 如果结束了,添加 usage 信息并发送最终的 [DONE]
|
||||
if (finishReason === 'STOP') {
|
||||
// 如果有 usage 数据,添加到最后一个 chunk
|
||||
@@ -415,48 +445,50 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion.chunk',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: 'stop'
|
||||
}],
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: 'stop'
|
||||
}
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: data.response.usageMetadata.promptTokenCount || 0,
|
||||
completion_tokens: data.response.usageMetadata.candidatesTokenCount || 0,
|
||||
total_tokens: data.response.usageMetadata.totalTokenCount || 0
|
||||
}
|
||||
};
|
||||
res.write(`data: ${JSON.stringify(usageChunk)}\n\n`);
|
||||
}
|
||||
res.write(`data: ${JSON.stringify(usageChunk)}\n\n`)
|
||||
}
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.write('data: [DONE]\n\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug('Error parsing JSON line:', e.message);
|
||||
logger.debug('Error parsing JSON line:', e.message)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Stream processing error:', error);
|
||||
logger.error('Stream processing error:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Stream error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
streamResponse.on('end', async () => {
|
||||
logger.info('Stream completed successfully');
|
||||
|
||||
logger.info('Stream completed successfully')
|
||||
|
||||
// 记录使用统计
|
||||
if (!usageReported && totalUsage.totalTokenCount > 0) {
|
||||
try {
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
await apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
totalUsage.promptTokenCount || 0,
|
||||
@@ -465,59 +497,60 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.write('data: [DONE]\n\n')
|
||||
}
|
||||
res.end();
|
||||
});
|
||||
|
||||
res.end()
|
||||
})
|
||||
|
||||
streamResponse.on('error', (error) => {
|
||||
logger.error('Stream error:', error);
|
||||
logger.error('Stream error:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Stream error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
} else {
|
||||
// 如果已经开始发送流数据,发送错误事件
|
||||
res.write(`data: {"error": {"message": "${error.message || 'Stream error'}"}}\n\n`);
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.end();
|
||||
res.write(`data: {"error": {"message": "${error.message || 'Stream error'}"}}\n\n`)
|
||||
res.write('data: [DONE]\n\n')
|
||||
res.end()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} else {
|
||||
// 非流式响应
|
||||
logger.info('GenerateContent request', {
|
||||
model: model,
|
||||
model,
|
||||
projectId: account.projectId,
|
||||
apiKeyId: apiKeyData.id
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
const response = await geminiAccountService.generateContent(
|
||||
client,
|
||||
{ model, request: geminiRequestBody },
|
||||
null, // user_prompt_id
|
||||
account.projectId, // 使用有权限的项目ID
|
||||
apiKeyData.id // 使用 API Key ID 作为 session ID
|
||||
);
|
||||
|
||||
)
|
||||
|
||||
// 转换为 OpenAI 格式并返回
|
||||
const openaiResponse = convertGeminiResponseToOpenAI(response, model, false);
|
||||
|
||||
const openaiResponse = convertGeminiResponseToOpenAI(response, model, false)
|
||||
|
||||
// 记录使用统计
|
||||
if (openaiResponse.usage) {
|
||||
try {
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
await apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
openaiResponse.usage.prompt_tokens || 0,
|
||||
@@ -526,53 +559,55 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini usage - Input: ${openaiResponse.usage.prompt_tokens}, Output: ${openaiResponse.usage.completion_tokens}, Total: ${openaiResponse.usage.total_tokens}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini usage - Input: ${openaiResponse.usage.prompt_tokens}, Output: ${openaiResponse.usage.completion_tokens}, Total: ${openaiResponse.usage.total_tokens}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
res.json(openaiResponse);
|
||||
|
||||
res.json(openaiResponse)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.info(`OpenAI-Gemini request completed in ${duration}ms`);
|
||||
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`OpenAI-Gemini request completed in ${duration}ms`)
|
||||
} catch (error) {
|
||||
logger.error('OpenAI-Gemini request error:', error);
|
||||
|
||||
logger.error('OpenAI-Gemini request error:', error)
|
||||
|
||||
// 处理速率限制
|
||||
if (error.status === 429) {
|
||||
if (req.apiKey && account && accountSelection) {
|
||||
await unifiedGeminiScheduler.markAccountRateLimited(account.id, 'gemini', sessionHash);
|
||||
await unifiedGeminiScheduler.markAccountRateLimited(account.id, 'gemini', sessionHash)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 返回 OpenAI 格式的错误响应
|
||||
const status = error.status || 500;
|
||||
const status = error.status || 500
|
||||
const errorResponse = {
|
||||
error: error.error || {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
};
|
||||
|
||||
res.status(status).json(errorResponse);
|
||||
}
|
||||
|
||||
res.status(status).json(errorResponse)
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// OpenAI 兼容的模型列表端点
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
return res.status(403).json({
|
||||
@@ -581,23 +616,27 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 选择账户获取模型列表
|
||||
let account = null;
|
||||
let account = null
|
||||
try {
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(apiKeyData, null, null);
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId);
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
null,
|
||||
null
|
||||
)
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error);
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error)
|
||||
}
|
||||
|
||||
let models = [];
|
||||
|
||||
|
||||
let models = []
|
||||
|
||||
if (account) {
|
||||
// 获取实际的模型列表
|
||||
models = await getAvailableModels(account.accessToken, account.proxy);
|
||||
models = await getAvailableModels(account.accessToken, account.proxy)
|
||||
} else {
|
||||
// 返回默认模型列表
|
||||
models = [
|
||||
@@ -607,37 +646,37 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
owned_by: 'google'
|
||||
}
|
||||
];
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// 如果启用了模型限制,过滤模型列表
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels.length > 0) {
|
||||
models = models.filter(model => apiKeyData.restrictedModels.includes(model.id));
|
||||
models = models.filter((model) => apiKeyData.restrictedModels.includes(model.id))
|
||||
}
|
||||
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get OpenAI-Gemini models:', error);
|
||||
logger.error('Failed to get OpenAI-Gemini models:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// OpenAI 兼容的模型详情端点
|
||||
router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const modelId = req.params.model;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
const modelId = req.params.model
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
return res.status(403).json({
|
||||
@@ -646,9 +685,9 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(modelId)) {
|
||||
@@ -658,10 +697,10 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_found'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 返回模型信息
|
||||
res.json({
|
||||
id: modelId,
|
||||
@@ -671,18 +710,18 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
permission: [],
|
||||
root: modelId,
|
||||
parent: null
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get model details:', error);
|
||||
logger.error('Failed to get model details:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve model details',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,158 +1,157 @@
|
||||
const express = require('express');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const crypto = require('crypto');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const config = require('../../config/config');
|
||||
const express = require('express')
|
||||
const bcrypt = require('bcryptjs')
|
||||
const crypto = require('crypto')
|
||||
const path = require('path')
|
||||
const fs = require('fs')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
|
||||
const router = express.Router();
|
||||
const router = express.Router()
|
||||
|
||||
// 🏠 服务静态文件
|
||||
router.use('/assets', express.static(path.join(__dirname, '../../web/assets')));
|
||||
router.use('/assets', express.static(path.join(__dirname, '../../web/assets')))
|
||||
|
||||
// 🌐 页面路由重定向到新版 admin-spa
|
||||
router.get('/', (req, res) => {
|
||||
res.redirect(301, '/admin-next/api-stats');
|
||||
});
|
||||
res.redirect(301, '/admin-next/api-stats')
|
||||
})
|
||||
|
||||
// 🔐 管理员登录
|
||||
router.post('/auth/login', async (req, res) => {
|
||||
try {
|
||||
const { username, password } = req.body;
|
||||
const { username, password } = req.body
|
||||
|
||||
if (!username || !password) {
|
||||
return res.status(400).json({
|
||||
error: 'Missing credentials',
|
||||
message: 'Username and password are required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 从Redis获取管理员信息
|
||||
let adminData = await redis.getSession('admin_credentials');
|
||||
|
||||
let adminData = await redis.getSession('admin_credentials')
|
||||
|
||||
// 如果Redis中没有管理员凭据,尝试从init.json重新加载
|
||||
if (!adminData || Object.keys(adminData).length === 0) {
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json');
|
||||
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json')
|
||||
|
||||
if (fs.existsSync(initFilePath)) {
|
||||
try {
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'));
|
||||
const saltRounds = 10;
|
||||
const passwordHash = await bcrypt.hash(initData.adminPassword, saltRounds);
|
||||
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'))
|
||||
const saltRounds = 10
|
||||
const passwordHash = await bcrypt.hash(initData.adminPassword, saltRounds)
|
||||
|
||||
adminData = {
|
||||
username: initData.adminUsername,
|
||||
passwordHash: passwordHash,
|
||||
passwordHash,
|
||||
createdAt: initData.initializedAt || new Date().toISOString(),
|
||||
lastLogin: null,
|
||||
updatedAt: initData.updatedAt || null
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 重新存储到Redis,不设置过期时间
|
||||
await redis.getClient().hset('session:admin_credentials', adminData);
|
||||
|
||||
logger.info('✅ Admin credentials reloaded from init.json');
|
||||
await redis.getClient().hset('session:admin_credentials', adminData)
|
||||
|
||||
logger.info('✅ Admin credentials reloaded from init.json')
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to reload admin credentials:', error);
|
||||
logger.error('❌ Failed to reload admin credentials:', error)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid credentials',
|
||||
message: 'Invalid username or password'
|
||||
});
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid credentials',
|
||||
message: 'Invalid username or password'
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 验证用户名和密码
|
||||
const isValidUsername = adminData.username === username;
|
||||
const isValidPassword = await bcrypt.compare(password, adminData.passwordHash);
|
||||
const isValidUsername = adminData.username === username
|
||||
const isValidPassword = await bcrypt.compare(password, adminData.passwordHash)
|
||||
|
||||
if (!isValidUsername || !isValidPassword) {
|
||||
logger.security(`🔒 Failed login attempt for username: ${username}`);
|
||||
logger.security(`🔒 Failed login attempt for username: ${username}`)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid credentials',
|
||||
message: 'Invalid username or password'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 生成会话token
|
||||
const sessionId = crypto.randomBytes(32).toString('hex');
|
||||
|
||||
const sessionId = crypto.randomBytes(32).toString('hex')
|
||||
|
||||
// 存储会话
|
||||
const sessionData = {
|
||||
username: adminData.username,
|
||||
loginTime: new Date().toISOString(),
|
||||
lastActivity: new Date().toISOString()
|
||||
};
|
||||
|
||||
await redis.setSession(sessionId, sessionData, config.security.adminSessionTimeout);
|
||||
|
||||
}
|
||||
|
||||
await redis.setSession(sessionId, sessionData, config.security.adminSessionTimeout)
|
||||
|
||||
// 不再更新 Redis 中的最后登录时间,因为 Redis 只是缓存
|
||||
// init.json 是唯一真实数据源
|
||||
|
||||
logger.success(`🔐 Admin login successful: ${username}`);
|
||||
logger.success(`🔐 Admin login successful: ${username}`)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
token: sessionId,
|
||||
expiresIn: config.security.adminSessionTimeout,
|
||||
username: adminData.username // 返回真实用户名
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Login error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Login error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Login failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🚪 管理员登出
|
||||
router.post('/auth/logout', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (token) {
|
||||
await redis.deleteSession(token);
|
||||
logger.success('🚪 Admin logout successful');
|
||||
await redis.deleteSession(token)
|
||||
logger.success('🚪 Admin logout successful')
|
||||
}
|
||||
|
||||
res.json({ success: true, message: 'Logout successful' });
|
||||
return res.json({ success: true, message: 'Logout successful' })
|
||||
} catch (error) {
|
||||
logger.error('❌ Logout error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Logout error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Logout failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🔑 修改账户信息
|
||||
router.post('/auth/change-password', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (!token) {
|
||||
return res.status(401).json({
|
||||
error: 'No token provided',
|
||||
message: 'Authentication required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const { newUsername, currentPassword, newPassword } = req.body;
|
||||
const { newUsername, currentPassword, newPassword } = req.body
|
||||
|
||||
if (!currentPassword || !newPassword) {
|
||||
return res.status(400).json({
|
||||
error: 'Missing required fields',
|
||||
message: 'Current password and new password are required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证新密码长度
|
||||
@@ -160,189 +159,186 @@ router.post('/auth/change-password', async (req, res) => {
|
||||
return res.status(400).json({
|
||||
error: 'Password too short',
|
||||
message: 'New password must be at least 8 characters long'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取当前会话
|
||||
const sessionData = await redis.getSession(token);
|
||||
const sessionData = await redis.getSession(token)
|
||||
if (!sessionData) {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid token',
|
||||
message: 'Session expired or invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取当前管理员信息
|
||||
const adminData = await redis.getSession('admin_credentials');
|
||||
const adminData = await redis.getSession('admin_credentials')
|
||||
if (!adminData) {
|
||||
return res.status(500).json({
|
||||
error: 'Admin data not found',
|
||||
message: 'Administrator credentials not found'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证当前密码
|
||||
const isValidPassword = await bcrypt.compare(currentPassword, adminData.passwordHash);
|
||||
const isValidPassword = await bcrypt.compare(currentPassword, adminData.passwordHash)
|
||||
if (!isValidPassword) {
|
||||
logger.security(`🔒 Invalid current password attempt for user: ${sessionData.username}`);
|
||||
logger.security(`🔒 Invalid current password attempt for user: ${sessionData.username}`)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid current password',
|
||||
message: 'Current password is incorrect'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 准备更新的数据
|
||||
const updatedUsername = newUsername && newUsername.trim() ? newUsername.trim() : adminData.username;
|
||||
|
||||
const updatedUsername =
|
||||
newUsername && newUsername.trim() ? newUsername.trim() : adminData.username
|
||||
|
||||
// 先更新 init.json(唯一真实数据源)
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json');
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json')
|
||||
if (!fs.existsSync(initFilePath)) {
|
||||
return res.status(500).json({
|
||||
error: 'Configuration file not found',
|
||||
message: 'init.json file is missing'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'));
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'))
|
||||
// const oldData = { ...initData }; // 备份旧数据
|
||||
|
||||
|
||||
// 更新 init.json
|
||||
initData.adminUsername = updatedUsername;
|
||||
initData.adminPassword = newPassword; // 保存明文密码到init.json
|
||||
initData.updatedAt = new Date().toISOString();
|
||||
|
||||
initData.adminUsername = updatedUsername
|
||||
initData.adminPassword = newPassword // 保存明文密码到init.json
|
||||
initData.updatedAt = new Date().toISOString()
|
||||
|
||||
// 先写入文件(如果失败则不会影响 Redis)
|
||||
fs.writeFileSync(initFilePath, JSON.stringify(initData, null, 2));
|
||||
|
||||
fs.writeFileSync(initFilePath, JSON.stringify(initData, null, 2))
|
||||
|
||||
// 文件写入成功后,更新 Redis 缓存
|
||||
const saltRounds = 10;
|
||||
const newPasswordHash = await bcrypt.hash(newPassword, saltRounds);
|
||||
|
||||
const saltRounds = 10
|
||||
const newPasswordHash = await bcrypt.hash(newPassword, saltRounds)
|
||||
|
||||
const updatedAdminData = {
|
||||
username: updatedUsername,
|
||||
passwordHash: newPasswordHash,
|
||||
createdAt: adminData.createdAt,
|
||||
lastLogin: adminData.lastLogin,
|
||||
updatedAt: new Date().toISOString()
|
||||
};
|
||||
|
||||
await redis.setSession('admin_credentials', updatedAdminData);
|
||||
|
||||
}
|
||||
|
||||
await redis.setSession('admin_credentials', updatedAdminData)
|
||||
} catch (fileError) {
|
||||
logger.error('❌ Failed to update init.json:', fileError);
|
||||
logger.error('❌ Failed to update init.json:', fileError)
|
||||
return res.status(500).json({
|
||||
error: 'Update failed',
|
||||
message: 'Failed to update configuration file'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 清除当前会话(强制用户重新登录)
|
||||
await redis.deleteSession(token);
|
||||
await redis.deleteSession(token)
|
||||
|
||||
logger.success(`🔐 Admin password changed successfully for user: ${updatedUsername}`);
|
||||
logger.success(`🔐 Admin password changed successfully for user: ${updatedUsername}`)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
message: 'Password changed successfully. Please login again.',
|
||||
newUsername: updatedUsername
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Change password error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Change password error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Change password failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 👤 获取当前用户信息
|
||||
router.get('/auth/user', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (!token) {
|
||||
return res.status(401).json({
|
||||
error: 'No token provided',
|
||||
message: 'Authentication required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取当前会话
|
||||
const sessionData = await redis.getSession(token);
|
||||
const sessionData = await redis.getSession(token)
|
||||
if (!sessionData) {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid token',
|
||||
message: 'Session expired or invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取管理员信息
|
||||
const adminData = await redis.getSession('admin_credentials');
|
||||
const adminData = await redis.getSession('admin_credentials')
|
||||
if (!adminData) {
|
||||
return res.status(500).json({
|
||||
error: 'Admin data not found',
|
||||
message: 'Administrator credentials not found'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
user: {
|
||||
username: adminData.username,
|
||||
loginTime: sessionData.loginTime,
|
||||
lastActivity: sessionData.lastActivity
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Get user info error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Get user info error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Get user info failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🔄 刷新token
|
||||
router.post('/auth/refresh', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (!token) {
|
||||
return res.status(401).json({
|
||||
error: 'No token provided',
|
||||
message: 'Authentication required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const sessionData = await redis.getSession(token);
|
||||
|
||||
const sessionData = await redis.getSession(token)
|
||||
|
||||
if (!sessionData) {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid token',
|
||||
message: 'Session expired or invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 更新最后活动时间
|
||||
sessionData.lastActivity = new Date().toISOString();
|
||||
await redis.setSession(token, sessionData, config.security.adminSessionTimeout);
|
||||
sessionData.lastActivity = new Date().toISOString()
|
||||
await redis.setSession(token, sessionData, config.security.adminSessionTimeout)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
token: token,
|
||||
token,
|
||||
expiresIn: config.security.adminSessionTimeout
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Token refresh error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Token refresh error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Token refresh failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
Reference in New Issue
Block a user