mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-22 16:43:35 +00:00
refactor: standardize code formatting and linting configuration
- Replace .eslintrc.js with .eslintrc.cjs for better ES module compatibility - Add .prettierrc configuration for consistent code formatting - Update package.json with new lint and format scripts - Add nodemon.json for development hot reloading configuration - Standardize code formatting across all JavaScript and Vue files - Update web admin SPA with improved linting rules and formatting - Add prettier configuration to web admin SPA 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
586
src/app.js
586
src/app.js
@@ -1,255 +1,268 @@
|
||||
const express = require('express');
|
||||
const cors = require('cors');
|
||||
const helmet = require('helmet');
|
||||
const morgan = require('morgan');
|
||||
const compression = require('compression');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const express = require('express')
|
||||
const cors = require('cors')
|
||||
const helmet = require('helmet')
|
||||
const compression = require('compression')
|
||||
const path = require('path')
|
||||
const fs = require('fs')
|
||||
const bcrypt = require('bcryptjs')
|
||||
|
||||
const config = require('../config/config');
|
||||
const logger = require('./utils/logger');
|
||||
const redis = require('./models/redis');
|
||||
const pricingService = require('./services/pricingService');
|
||||
const config = require('../config/config')
|
||||
const logger = require('./utils/logger')
|
||||
const redis = require('./models/redis')
|
||||
const pricingService = require('./services/pricingService')
|
||||
|
||||
// Import routes
|
||||
const apiRoutes = require('./routes/api');
|
||||
const adminRoutes = require('./routes/admin');
|
||||
const webRoutes = require('./routes/web');
|
||||
const apiStatsRoutes = require('./routes/apiStats');
|
||||
const geminiRoutes = require('./routes/geminiRoutes');
|
||||
const openaiGeminiRoutes = require('./routes/openaiGeminiRoutes');
|
||||
const openaiClaudeRoutes = require('./routes/openaiClaudeRoutes');
|
||||
const apiRoutes = require('./routes/api')
|
||||
const adminRoutes = require('./routes/admin')
|
||||
const webRoutes = require('./routes/web')
|
||||
const apiStatsRoutes = require('./routes/apiStats')
|
||||
const geminiRoutes = require('./routes/geminiRoutes')
|
||||
const openaiGeminiRoutes = require('./routes/openaiGeminiRoutes')
|
||||
const openaiClaudeRoutes = require('./routes/openaiClaudeRoutes')
|
||||
|
||||
// Import middleware
|
||||
const {
|
||||
corsMiddleware,
|
||||
requestLogger,
|
||||
securityMiddleware,
|
||||
const {
|
||||
corsMiddleware,
|
||||
requestLogger,
|
||||
securityMiddleware,
|
||||
errorHandler,
|
||||
globalRateLimit,
|
||||
requestSizeLimit
|
||||
} = require('./middleware/auth');
|
||||
} = require('./middleware/auth')
|
||||
|
||||
class Application {
|
||||
constructor() {
|
||||
this.app = express();
|
||||
this.server = null;
|
||||
this.app = express()
|
||||
this.server = null
|
||||
}
|
||||
|
||||
async initialize() {
|
||||
try {
|
||||
// 🔗 连接Redis
|
||||
logger.info('🔄 Connecting to Redis...');
|
||||
await redis.connect();
|
||||
logger.success('✅ Redis connected successfully');
|
||||
|
||||
logger.info('🔄 Connecting to Redis...')
|
||||
await redis.connect()
|
||||
logger.success('✅ Redis connected successfully')
|
||||
|
||||
// 💰 初始化价格服务
|
||||
logger.info('🔄 Initializing pricing service...');
|
||||
await pricingService.initialize();
|
||||
|
||||
logger.info('🔄 Initializing pricing service...')
|
||||
await pricingService.initialize()
|
||||
|
||||
// 🔧 初始化管理员凭据
|
||||
logger.info('🔄 Initializing admin credentials...');
|
||||
await this.initializeAdmin();
|
||||
|
||||
logger.info('🔄 Initializing admin credentials...')
|
||||
await this.initializeAdmin()
|
||||
|
||||
// 💰 初始化费用数据
|
||||
logger.info('💰 Checking cost data initialization...');
|
||||
const costInitService = require('./services/costInitService');
|
||||
const needsInit = await costInitService.needsInitialization();
|
||||
logger.info('💰 Checking cost data initialization...')
|
||||
const costInitService = require('./services/costInitService')
|
||||
const needsInit = await costInitService.needsInitialization()
|
||||
if (needsInit) {
|
||||
logger.info('💰 Initializing cost data for all API Keys...');
|
||||
const result = await costInitService.initializeAllCosts();
|
||||
logger.info(`💰 Cost initialization completed: ${result.processed} processed, ${result.errors} errors`);
|
||||
logger.info('💰 Initializing cost data for all API Keys...')
|
||||
const result = await costInitService.initializeAllCosts()
|
||||
logger.info(
|
||||
`💰 Cost initialization completed: ${result.processed} processed, ${result.errors} errors`
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// 🕐 初始化Claude账户会话窗口
|
||||
logger.info('🕐 Initializing Claude account session windows...');
|
||||
const claudeAccountService = require('./services/claudeAccountService');
|
||||
await claudeAccountService.initializeSessionWindows();
|
||||
|
||||
logger.info('🕐 Initializing Claude account session windows...')
|
||||
const claudeAccountService = require('./services/claudeAccountService')
|
||||
await claudeAccountService.initializeSessionWindows()
|
||||
|
||||
// 超早期拦截 /admin-next/ 请求 - 在所有中间件之前
|
||||
this.app.use((req, res, next) => {
|
||||
if (req.path === '/admin-next/' && req.method === 'GET') {
|
||||
logger.warn(`🚨 INTERCEPTING /admin-next/ request at the very beginning!`);
|
||||
const adminSpaPath = path.join(__dirname, '..', 'web', 'admin-spa', 'dist');
|
||||
const indexPath = path.join(adminSpaPath, 'index.html');
|
||||
|
||||
logger.warn('🚨 INTERCEPTING /admin-next/ request at the very beginning!')
|
||||
const adminSpaPath = path.join(__dirname, '..', 'web', 'admin-spa', 'dist')
|
||||
const indexPath = path.join(adminSpaPath, 'index.html')
|
||||
|
||||
if (fs.existsSync(indexPath)) {
|
||||
res.setHeader('Cache-Control', 'no-cache, no-store, must-revalidate');
|
||||
return res.sendFile(indexPath);
|
||||
res.setHeader('Cache-Control', 'no-cache, no-store, must-revalidate')
|
||||
return res.sendFile(indexPath)
|
||||
} else {
|
||||
logger.error('❌ index.html not found at:', indexPath);
|
||||
return res.status(404).send('index.html not found');
|
||||
logger.error('❌ index.html not found at:', indexPath)
|
||||
return res.status(404).send('index.html not found')
|
||||
}
|
||||
}
|
||||
next();
|
||||
});
|
||||
|
||||
next()
|
||||
})
|
||||
|
||||
// 🛡️ 安全中间件
|
||||
this.app.use(helmet({
|
||||
contentSecurityPolicy: false, // 允许内联样式和脚本
|
||||
crossOriginEmbedderPolicy: false
|
||||
}));
|
||||
|
||||
this.app.use(
|
||||
helmet({
|
||||
contentSecurityPolicy: false, // 允许内联样式和脚本
|
||||
crossOriginEmbedderPolicy: false
|
||||
})
|
||||
)
|
||||
|
||||
// 🌐 CORS
|
||||
if (config.web.enableCors) {
|
||||
this.app.use(cors());
|
||||
this.app.use(cors())
|
||||
} else {
|
||||
this.app.use(corsMiddleware);
|
||||
this.app.use(corsMiddleware)
|
||||
}
|
||||
|
||||
|
||||
// 📦 压缩 - 排除流式响应(SSE)
|
||||
this.app.use(compression({
|
||||
filter: (req, res) => {
|
||||
// 不压缩 Server-Sent Events
|
||||
if (res.getHeader('Content-Type') === 'text/event-stream') {
|
||||
return false;
|
||||
this.app.use(
|
||||
compression({
|
||||
filter: (req, res) => {
|
||||
// 不压缩 Server-Sent Events
|
||||
if (res.getHeader('Content-Type') === 'text/event-stream') {
|
||||
return false
|
||||
}
|
||||
// 使用默认的压缩判断
|
||||
return compression.filter(req, res)
|
||||
}
|
||||
// 使用默认的压缩判断
|
||||
return compression.filter(req, res);
|
||||
}
|
||||
}));
|
||||
|
||||
})
|
||||
)
|
||||
|
||||
// 🚦 全局速率限制(仅在生产环境启用)
|
||||
if (process.env.NODE_ENV === 'production') {
|
||||
this.app.use(globalRateLimit);
|
||||
this.app.use(globalRateLimit)
|
||||
}
|
||||
|
||||
|
||||
// 📏 请求大小限制
|
||||
this.app.use(requestSizeLimit);
|
||||
|
||||
this.app.use(requestSizeLimit)
|
||||
|
||||
// 📝 请求日志(使用自定义logger而不是morgan)
|
||||
this.app.use(requestLogger);
|
||||
|
||||
this.app.use(requestLogger)
|
||||
|
||||
// 🔧 基础中间件
|
||||
this.app.use(express.json({
|
||||
limit: '10mb',
|
||||
verify: (req, res, buf, encoding) => {
|
||||
// 验证JSON格式
|
||||
if (buf && buf.length && !buf.toString(encoding || 'utf8').trim()) {
|
||||
throw new Error('Invalid JSON: empty body');
|
||||
this.app.use(
|
||||
express.json({
|
||||
limit: '10mb',
|
||||
verify: (req, res, buf, encoding) => {
|
||||
// 验证JSON格式
|
||||
if (buf && buf.length && !buf.toString(encoding || 'utf8').trim()) {
|
||||
throw new Error('Invalid JSON: empty body')
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
this.app.use(express.urlencoded({ extended: true, limit: '10mb' }));
|
||||
this.app.use(securityMiddleware);
|
||||
|
||||
})
|
||||
)
|
||||
this.app.use(express.urlencoded({ extended: true, limit: '10mb' }))
|
||||
this.app.use(securityMiddleware)
|
||||
|
||||
// 🎯 信任代理
|
||||
if (config.server.trustProxy) {
|
||||
this.app.set('trust proxy', 1);
|
||||
this.app.set('trust proxy', 1)
|
||||
}
|
||||
|
||||
// 调试中间件 - 拦截所有 /admin-next 请求
|
||||
this.app.use((req, res, next) => {
|
||||
if (req.path.startsWith('/admin-next')) {
|
||||
logger.info(`🔍 DEBUG: Incoming request - method: ${req.method}, path: ${req.path}, originalUrl: ${req.originalUrl}`);
|
||||
logger.info(
|
||||
`🔍 DEBUG: Incoming request - method: ${req.method}, path: ${req.path}, originalUrl: ${req.originalUrl}`
|
||||
)
|
||||
}
|
||||
next();
|
||||
});
|
||||
|
||||
next()
|
||||
})
|
||||
|
||||
// 🎨 新版管理界面静态文件服务(必须在其他路由之前)
|
||||
const adminSpaPath = path.join(__dirname, '..', 'web', 'admin-spa', 'dist');
|
||||
const adminSpaPath = path.join(__dirname, '..', 'web', 'admin-spa', 'dist')
|
||||
if (fs.existsSync(adminSpaPath)) {
|
||||
// 处理不带斜杠的路径,重定向到带斜杠的路径
|
||||
this.app.get('/admin-next', (req, res) => {
|
||||
res.redirect(301, '/admin-next/');
|
||||
});
|
||||
|
||||
res.redirect(301, '/admin-next/')
|
||||
})
|
||||
|
||||
// 使用 all 方法确保捕获所有 HTTP 方法
|
||||
this.app.all('/admin-next/', (req, res) => {
|
||||
logger.info('🎯 HIT: /admin-next/ route handler triggered!');
|
||||
logger.info(`Method: ${req.method}, Path: ${req.path}, URL: ${req.url}`);
|
||||
|
||||
logger.info('🎯 HIT: /admin-next/ route handler triggered!')
|
||||
logger.info(`Method: ${req.method}, Path: ${req.path}, URL: ${req.url}`)
|
||||
|
||||
if (req.method !== 'GET' && req.method !== 'HEAD') {
|
||||
return res.status(405).send('Method Not Allowed');
|
||||
return res.status(405).send('Method Not Allowed')
|
||||
}
|
||||
|
||||
res.setHeader('Cache-Control', 'no-cache, no-store, must-revalidate');
|
||||
res.sendFile(path.join(adminSpaPath, 'index.html'));
|
||||
});
|
||||
|
||||
|
||||
res.setHeader('Cache-Control', 'no-cache, no-store, must-revalidate')
|
||||
res.sendFile(path.join(adminSpaPath, 'index.html'))
|
||||
})
|
||||
|
||||
// 处理所有其他 /admin-next/* 路径(但排除根路径)
|
||||
this.app.get('/admin-next/*', (req, res) => {
|
||||
// 如果是根路径,跳过(应该由上面的路由处理)
|
||||
if (req.path === '/admin-next/') {
|
||||
logger.error('❌ ERROR: /admin-next/ should not reach here!');
|
||||
return res.status(500).send('Route configuration error');
|
||||
logger.error('❌ ERROR: /admin-next/ should not reach here!')
|
||||
return res.status(500).send('Route configuration error')
|
||||
}
|
||||
|
||||
const requestPath = req.path.replace('/admin-next/', '');
|
||||
|
||||
|
||||
const requestPath = req.path.replace('/admin-next/', '')
|
||||
|
||||
// 安全检查
|
||||
if (requestPath.includes('..') || requestPath.includes('//') || requestPath.includes('\\')) {
|
||||
return res.status(400).json({ error: 'Invalid path' });
|
||||
if (
|
||||
requestPath.includes('..') ||
|
||||
requestPath.includes('//') ||
|
||||
requestPath.includes('\\')
|
||||
) {
|
||||
return res.status(400).json({ error: 'Invalid path' })
|
||||
}
|
||||
|
||||
|
||||
// 检查是否为静态资源
|
||||
const filePath = path.join(adminSpaPath, requestPath);
|
||||
|
||||
const filePath = path.join(adminSpaPath, requestPath)
|
||||
|
||||
// 如果文件存在且是静态资源
|
||||
if (fs.existsSync(filePath) && fs.statSync(filePath).isFile()) {
|
||||
// 设置缓存头
|
||||
if (filePath.endsWith('.js') || filePath.endsWith('.css')) {
|
||||
res.setHeader('Cache-Control', 'public, max-age=31536000, immutable');
|
||||
res.setHeader('Cache-Control', 'public, max-age=31536000, immutable')
|
||||
} else if (filePath.endsWith('.html')) {
|
||||
res.setHeader('Cache-Control', 'no-cache, no-store, must-revalidate');
|
||||
res.setHeader('Cache-Control', 'no-cache, no-store, must-revalidate')
|
||||
}
|
||||
return res.sendFile(filePath);
|
||||
return res.sendFile(filePath)
|
||||
}
|
||||
|
||||
|
||||
// 如果是静态资源但文件不存在
|
||||
if (requestPath.match(/\.(js|css|png|jpg|jpeg|gif|svg|ico|woff|woff2|ttf)$/i)) {
|
||||
return res.status(404).send('Not found');
|
||||
return res.status(404).send('Not found')
|
||||
}
|
||||
|
||||
|
||||
// 其他所有路径返回 index.html(SPA 路由)
|
||||
res.sendFile(path.join(adminSpaPath, 'index.html'));
|
||||
});
|
||||
|
||||
logger.info('✅ Admin SPA (next) static files mounted at /admin-next/');
|
||||
res.sendFile(path.join(adminSpaPath, 'index.html'))
|
||||
})
|
||||
|
||||
logger.info('✅ Admin SPA (next) static files mounted at /admin-next/')
|
||||
} else {
|
||||
logger.warn('⚠️ Admin SPA dist directory not found, skipping /admin-next route');
|
||||
logger.warn('⚠️ Admin SPA dist directory not found, skipping /admin-next route')
|
||||
}
|
||||
|
||||
// 🛣️ 路由
|
||||
this.app.use('/api', apiRoutes);
|
||||
this.app.use('/claude', apiRoutes); // /claude 路由别名,与 /api 功能相同
|
||||
this.app.use('/admin', adminRoutes);
|
||||
this.app.use('/api', apiRoutes)
|
||||
this.app.use('/claude', apiRoutes) // /claude 路由别名,与 /api 功能相同
|
||||
this.app.use('/admin', adminRoutes)
|
||||
// 使用 web 路由(包含 auth 和页面重定向)
|
||||
this.app.use('/web', webRoutes);
|
||||
this.app.use('/apiStats', apiStatsRoutes);
|
||||
this.app.use('/gemini', geminiRoutes);
|
||||
this.app.use('/openai/gemini', openaiGeminiRoutes);
|
||||
this.app.use('/openai/claude', openaiClaudeRoutes);
|
||||
|
||||
this.app.use('/web', webRoutes)
|
||||
this.app.use('/apiStats', apiStatsRoutes)
|
||||
this.app.use('/gemini', geminiRoutes)
|
||||
this.app.use('/openai/gemini', openaiGeminiRoutes)
|
||||
this.app.use('/openai/claude', openaiClaudeRoutes)
|
||||
|
||||
// 🏠 根路径重定向到新版管理界面
|
||||
this.app.get('/', (req, res) => {
|
||||
res.redirect('/admin-next/api-stats');
|
||||
});
|
||||
|
||||
res.redirect('/admin-next/api-stats')
|
||||
})
|
||||
|
||||
// 🏥 增强的健康检查端点
|
||||
this.app.get('/health', async (req, res) => {
|
||||
try {
|
||||
const timer = logger.timer('health-check');
|
||||
|
||||
const timer = logger.timer('health-check')
|
||||
|
||||
// 检查各个组件健康状态
|
||||
const [redisHealth, loggerHealth] = await Promise.all([
|
||||
this.checkRedisHealth(),
|
||||
this.checkLoggerHealth()
|
||||
]);
|
||||
|
||||
const memory = process.memoryUsage();
|
||||
|
||||
])
|
||||
|
||||
const memory = process.memoryUsage()
|
||||
|
||||
// 获取版本号:优先使用环境变量,其次VERSION文件,再次package.json,最后使用默认值
|
||||
let version = process.env.APP_VERSION || process.env.VERSION;
|
||||
let version = process.env.APP_VERSION || process.env.VERSION
|
||||
if (!version) {
|
||||
try {
|
||||
// 尝试从VERSION文件读取
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const versionFile = path.join(__dirname, '..', 'VERSION');
|
||||
const fs = require('fs')
|
||||
const path = require('path')
|
||||
const versionFile = path.join(__dirname, '..', 'VERSION')
|
||||
if (fs.existsSync(versionFile)) {
|
||||
version = fs.readFileSync(versionFile, 'utf8').trim();
|
||||
version = fs.readFileSync(versionFile, 'utf8').trim()
|
||||
}
|
||||
} catch (error) {
|
||||
// 忽略错误,继续尝试其他方式
|
||||
@@ -257,13 +270,13 @@ class Application {
|
||||
}
|
||||
if (!version) {
|
||||
try {
|
||||
const packageJson = require('../package.json');
|
||||
version = packageJson.version;
|
||||
const packageJson = require('../package.json')
|
||||
version = packageJson.version
|
||||
} catch (error) {
|
||||
version = '1.0.0';
|
||||
version = '1.0.0'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const health = {
|
||||
status: 'healthy',
|
||||
service: 'claude-relay-service',
|
||||
@@ -280,75 +293,74 @@ class Application {
|
||||
logger: loggerHealth
|
||||
},
|
||||
stats: logger.getStats()
|
||||
};
|
||||
|
||||
timer.end('completed');
|
||||
res.json(health);
|
||||
}
|
||||
|
||||
timer.end('completed')
|
||||
res.json(health)
|
||||
} catch (error) {
|
||||
logger.error('❌ Health check failed:', { error: error.message, stack: error.stack });
|
||||
logger.error('❌ Health check failed:', { error: error.message, stack: error.stack })
|
||||
res.status(503).json({
|
||||
status: 'unhealthy',
|
||||
error: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 📊 指标端点
|
||||
this.app.get('/metrics', async (req, res) => {
|
||||
try {
|
||||
const stats = await redis.getSystemStats();
|
||||
const stats = await redis.getSystemStats()
|
||||
const metrics = {
|
||||
...stats,
|
||||
uptime: process.uptime(),
|
||||
memory: process.memoryUsage(),
|
||||
timestamp: new Date().toISOString()
|
||||
};
|
||||
|
||||
res.json(metrics);
|
||||
}
|
||||
|
||||
res.json(metrics)
|
||||
} catch (error) {
|
||||
logger.error('❌ Metrics collection failed:', error);
|
||||
res.status(500).json({ error: 'Failed to collect metrics' });
|
||||
logger.error('❌ Metrics collection failed:', error)
|
||||
res.status(500).json({ error: 'Failed to collect metrics' })
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 🚫 404 处理
|
||||
this.app.use('*', (req, res) => {
|
||||
res.status(404).json({
|
||||
error: 'Not Found',
|
||||
message: `Route ${req.originalUrl} not found`,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
});
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
// 🚨 错误处理
|
||||
this.app.use(errorHandler);
|
||||
|
||||
logger.success('✅ Application initialized successfully');
|
||||
|
||||
this.app.use(errorHandler)
|
||||
|
||||
logger.success('✅ Application initialized successfully')
|
||||
} catch (error) {
|
||||
logger.error('💥 Application initialization failed:', error);
|
||||
throw error;
|
||||
logger.error('💥 Application initialization failed:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔧 初始化管理员凭据(总是从 init.json 加载,确保数据一致性)
|
||||
async initializeAdmin() {
|
||||
try {
|
||||
const initFilePath = path.join(__dirname, '..', 'data', 'init.json');
|
||||
|
||||
const initFilePath = path.join(__dirname, '..', 'data', 'init.json')
|
||||
|
||||
if (!fs.existsSync(initFilePath)) {
|
||||
logger.warn('⚠️ No admin credentials found. Please run npm run setup first.');
|
||||
return;
|
||||
logger.warn('⚠️ No admin credentials found. Please run npm run setup first.')
|
||||
return
|
||||
}
|
||||
|
||||
// 从 init.json 读取管理员凭据(作为唯一真实数据源)
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'));
|
||||
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'))
|
||||
|
||||
// 将明文密码哈希化
|
||||
const saltRounds = 10;
|
||||
const passwordHash = await bcrypt.hash(initData.adminPassword, saltRounds);
|
||||
|
||||
const saltRounds = 10
|
||||
const passwordHash = await bcrypt.hash(initData.adminPassword, saltRounds)
|
||||
|
||||
// 存储到Redis(每次启动都覆盖,确保与 init.json 同步)
|
||||
const adminCredentials = {
|
||||
username: initData.adminUsername,
|
||||
@@ -356,84 +368,90 @@ class Application {
|
||||
createdAt: initData.initializedAt || new Date().toISOString(),
|
||||
lastLogin: null,
|
||||
updatedAt: initData.updatedAt || null
|
||||
};
|
||||
|
||||
await redis.setSession('admin_credentials', adminCredentials);
|
||||
|
||||
logger.success('✅ Admin credentials loaded from init.json (single source of truth)');
|
||||
logger.info(`📋 Admin username: ${adminCredentials.username}`);
|
||||
|
||||
}
|
||||
|
||||
await redis.setSession('admin_credentials', adminCredentials)
|
||||
|
||||
logger.success('✅ Admin credentials loaded from init.json (single source of truth)')
|
||||
logger.info(`📋 Admin username: ${adminCredentials.username}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to initialize admin credentials:', { error: error.message, stack: error.stack });
|
||||
throw error;
|
||||
logger.error('❌ Failed to initialize admin credentials:', {
|
||||
error: error.message,
|
||||
stack: error.stack
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 Redis健康检查
|
||||
async checkRedisHealth() {
|
||||
try {
|
||||
const start = Date.now();
|
||||
await redis.getClient().ping();
|
||||
const latency = Date.now() - start;
|
||||
|
||||
const start = Date.now()
|
||||
await redis.getClient().ping()
|
||||
const latency = Date.now() - start
|
||||
|
||||
return {
|
||||
status: 'healthy',
|
||||
connected: redis.isConnected,
|
||||
latency: `${latency}ms`
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
status: 'unhealthy',
|
||||
connected: false,
|
||||
error: error.message
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 📝 Logger健康检查
|
||||
async checkLoggerHealth() {
|
||||
try {
|
||||
const health = logger.healthCheck();
|
||||
const health = logger.healthCheck()
|
||||
return {
|
||||
status: health.healthy ? 'healthy' : 'unhealthy',
|
||||
...health
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
status: 'unhealthy',
|
||||
error: error.message
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async start() {
|
||||
try {
|
||||
await this.initialize();
|
||||
|
||||
this.server = this.app.listen(config.server.port, config.server.host, () => {
|
||||
logger.start(`🚀 Claude Relay Service started on ${config.server.host}:${config.server.port}`);
|
||||
logger.info(`🌐 Web interface: http://${config.server.host}:${config.server.port}/admin-next/api-stats`);
|
||||
logger.info(`🔗 API endpoint: http://${config.server.host}:${config.server.port}/api/v1/messages`);
|
||||
logger.info(`⚙️ Admin API: http://${config.server.host}:${config.server.port}/admin`);
|
||||
logger.info(`🏥 Health check: http://${config.server.host}:${config.server.port}/health`);
|
||||
logger.info(`📊 Metrics: http://${config.server.host}:${config.server.port}/metrics`);
|
||||
});
|
||||
await this.initialize()
|
||||
|
||||
const serverTimeout = 600000; // 默认10分钟
|
||||
this.server.timeout = serverTimeout;
|
||||
this.server.keepAliveTimeout = serverTimeout + 5000; // keepAlive 稍长一点
|
||||
logger.info(`⏱️ Server timeout set to ${serverTimeout}ms (${serverTimeout/1000}s)`);
|
||||
|
||||
this.server = this.app.listen(config.server.port, config.server.host, () => {
|
||||
logger.start(
|
||||
`🚀 Claude Relay Service started on ${config.server.host}:${config.server.port}`
|
||||
)
|
||||
logger.info(
|
||||
`🌐 Web interface: http://${config.server.host}:${config.server.port}/admin-next/api-stats`
|
||||
)
|
||||
logger.info(
|
||||
`🔗 API endpoint: http://${config.server.host}:${config.server.port}/api/v1/messages`
|
||||
)
|
||||
logger.info(`⚙️ Admin API: http://${config.server.host}:${config.server.port}/admin`)
|
||||
logger.info(`🏥 Health check: http://${config.server.host}:${config.server.port}/health`)
|
||||
logger.info(`📊 Metrics: http://${config.server.host}:${config.server.port}/metrics`)
|
||||
})
|
||||
|
||||
const serverTimeout = 600000 // 默认10分钟
|
||||
this.server.timeout = serverTimeout
|
||||
this.server.keepAliveTimeout = serverTimeout + 5000 // keepAlive 稍长一点
|
||||
logger.info(`⏱️ Server timeout set to ${serverTimeout}ms (${serverTimeout / 1000}s)`)
|
||||
|
||||
// 🔄 定期清理任务
|
||||
this.startCleanupTasks();
|
||||
|
||||
this.startCleanupTasks()
|
||||
|
||||
// 🛑 优雅关闭
|
||||
this.setupGracefulShutdown();
|
||||
|
||||
this.setupGracefulShutdown()
|
||||
} catch (error) {
|
||||
logger.error('💥 Failed to start server:', error);
|
||||
process.exit(1);
|
||||
logger.error('💥 Failed to start server:', error)
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,87 +459,91 @@ class Application {
|
||||
// 🧹 每小时清理一次过期数据
|
||||
setInterval(async () => {
|
||||
try {
|
||||
logger.info('🧹 Starting scheduled cleanup...');
|
||||
|
||||
const apiKeyService = require('./services/apiKeyService');
|
||||
const claudeAccountService = require('./services/claudeAccountService');
|
||||
|
||||
logger.info('🧹 Starting scheduled cleanup...')
|
||||
|
||||
const apiKeyService = require('./services/apiKeyService')
|
||||
const claudeAccountService = require('./services/claudeAccountService')
|
||||
|
||||
const [expiredKeys, errorAccounts] = await Promise.all([
|
||||
apiKeyService.cleanupExpiredKeys(),
|
||||
claudeAccountService.cleanupErrorAccounts()
|
||||
]);
|
||||
|
||||
await redis.cleanup();
|
||||
|
||||
logger.success(`🧹 Cleanup completed: ${expiredKeys} expired keys, ${errorAccounts} error accounts reset`);
|
||||
} catch (error) {
|
||||
logger.error('❌ Cleanup task failed:', error);
|
||||
}
|
||||
}, config.system.cleanupInterval);
|
||||
])
|
||||
|
||||
logger.info(`🔄 Cleanup tasks scheduled every ${config.system.cleanupInterval / 1000 / 60} minutes`);
|
||||
await redis.cleanup()
|
||||
|
||||
logger.success(
|
||||
`🧹 Cleanup completed: ${expiredKeys} expired keys, ${errorAccounts} error accounts reset`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('❌ Cleanup task failed:', error)
|
||||
}
|
||||
}, config.system.cleanupInterval)
|
||||
|
||||
logger.info(
|
||||
`🔄 Cleanup tasks scheduled every ${config.system.cleanupInterval / 1000 / 60} minutes`
|
||||
)
|
||||
}
|
||||
|
||||
setupGracefulShutdown() {
|
||||
const shutdown = async (signal) => {
|
||||
logger.info(`🛑 Received ${signal}, starting graceful shutdown...`);
|
||||
|
||||
logger.info(`🛑 Received ${signal}, starting graceful shutdown...`)
|
||||
|
||||
if (this.server) {
|
||||
this.server.close(async () => {
|
||||
logger.info('🚪 HTTP server closed');
|
||||
|
||||
logger.info('🚪 HTTP server closed')
|
||||
|
||||
// 清理 pricing service 的文件监听器
|
||||
try {
|
||||
pricingService.cleanup();
|
||||
logger.info('💰 Pricing service cleaned up');
|
||||
pricingService.cleanup()
|
||||
logger.info('💰 Pricing service cleaned up')
|
||||
} catch (error) {
|
||||
logger.error('❌ Error cleaning up pricing service:', error);
|
||||
logger.error('❌ Error cleaning up pricing service:', error)
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
await redis.disconnect();
|
||||
logger.info('👋 Redis disconnected');
|
||||
await redis.disconnect()
|
||||
logger.info('👋 Redis disconnected')
|
||||
} catch (error) {
|
||||
logger.error('❌ Error disconnecting Redis:', error);
|
||||
logger.error('❌ Error disconnecting Redis:', error)
|
||||
}
|
||||
|
||||
logger.success('✅ Graceful shutdown completed');
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
logger.success('✅ Graceful shutdown completed')
|
||||
process.exit(0)
|
||||
})
|
||||
|
||||
// 强制关闭超时
|
||||
setTimeout(() => {
|
||||
logger.warn('⚠️ Forced shutdown due to timeout');
|
||||
process.exit(1);
|
||||
}, 10000);
|
||||
logger.warn('⚠️ Forced shutdown due to timeout')
|
||||
process.exit(1)
|
||||
}, 10000)
|
||||
} else {
|
||||
process.exit(0);
|
||||
process.exit(0)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
process.on('SIGTERM', () => shutdown('SIGTERM'))
|
||||
process.on('SIGINT', () => shutdown('SIGINT'))
|
||||
|
||||
process.on('SIGTERM', () => shutdown('SIGTERM'));
|
||||
process.on('SIGINT', () => shutdown('SIGINT'));
|
||||
|
||||
// 处理未捕获异常
|
||||
process.on('uncaughtException', (error) => {
|
||||
logger.error('💥 Uncaught exception:', error);
|
||||
shutdown('uncaughtException');
|
||||
});
|
||||
|
||||
logger.error('💥 Uncaught exception:', error)
|
||||
shutdown('uncaughtException')
|
||||
})
|
||||
|
||||
process.on('unhandledRejection', (reason, promise) => {
|
||||
logger.error('💥 Unhandled rejection at:', promise, 'reason:', reason);
|
||||
shutdown('unhandledRejection');
|
||||
});
|
||||
logger.error('💥 Unhandled rejection at:', promise, 'reason:', reason)
|
||||
shutdown('unhandledRejection')
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 启动应用
|
||||
if (require.main === module) {
|
||||
const app = new Application();
|
||||
const app = new Application()
|
||||
app.start().catch((error) => {
|
||||
logger.error('💥 Application startup failed:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
logger.error('💥 Application startup failed:', error)
|
||||
process.exit(1)
|
||||
})
|
||||
}
|
||||
|
||||
module.exports = Application;
|
||||
module.exports = Application
|
||||
|
||||
@@ -1,32 +1,35 @@
|
||||
#!/usr/bin/env node
|
||||
|
||||
const costInitService = require('../services/costInitService');
|
||||
const logger = require('../utils/logger');
|
||||
const redis = require('../models/redis');
|
||||
const costInitService = require('../services/costInitService')
|
||||
const logger = require('../utils/logger')
|
||||
const redis = require('../models/redis')
|
||||
|
||||
async function main() {
|
||||
try {
|
||||
// 连接Redis
|
||||
await redis.connect();
|
||||
|
||||
console.log('💰 Starting cost data initialization...\n');
|
||||
|
||||
await redis.connect()
|
||||
|
||||
console.log('💰 Starting cost data initialization...\n')
|
||||
|
||||
// 执行初始化
|
||||
const result = await costInitService.initializeAllCosts();
|
||||
|
||||
console.log('\n✅ Cost initialization completed!');
|
||||
console.log(` Processed: ${result.processed} API Keys`);
|
||||
console.log(` Errors: ${result.errors}`);
|
||||
|
||||
const result = await costInitService.initializeAllCosts()
|
||||
|
||||
console.log('\n✅ Cost initialization completed!')
|
||||
console.log(` Processed: ${result.processed} API Keys`)
|
||||
console.log(` Errors: ${result.errors}`)
|
||||
|
||||
// 断开连接
|
||||
await redis.disconnect();
|
||||
process.exit(0);
|
||||
await redis.disconnect()
|
||||
throw new Error('INIT_COSTS_SUCCESS')
|
||||
} catch (error) {
|
||||
console.error('\n❌ Cost initialization failed:', error.message);
|
||||
logger.error('Cost initialization failed:', error);
|
||||
process.exit(1);
|
||||
if (error.message === 'INIT_COSTS_SUCCESS') {
|
||||
return
|
||||
}
|
||||
console.error('\n❌ Cost initialization failed:', error.message)
|
||||
logger.error('Cost initialization failed:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 运行主函数
|
||||
main();
|
||||
main()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1170
src/models/redis.js
1170
src/models/redis.js
File diff suppressed because it is too large
Load Diff
3846
src/routes/admin.js
3846
src/routes/admin.js
File diff suppressed because it is too large
Load Diff
@@ -1,360 +1,482 @@
|
||||
const express = require('express');
|
||||
const claudeRelayService = require('../services/claudeRelayService');
|
||||
const claudeConsoleRelayService = require('../services/claudeConsoleRelayService');
|
||||
const bedrockRelayService = require('../services/bedrockRelayService');
|
||||
const bedrockAccountService = require('../services/bedrockAccountService');
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const logger = require('../utils/logger');
|
||||
const redis = require('../models/redis');
|
||||
const sessionHelper = require('../utils/sessionHelper');
|
||||
const express = require('express')
|
||||
const claudeRelayService = require('../services/claudeRelayService')
|
||||
const claudeConsoleRelayService = require('../services/claudeConsoleRelayService')
|
||||
const bedrockRelayService = require('../services/bedrockRelayService')
|
||||
const bedrockAccountService = require('../services/bedrockAccountService')
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const logger = require('../utils/logger')
|
||||
const redis = require('../models/redis')
|
||||
const sessionHelper = require('../utils/sessionHelper')
|
||||
|
||||
const router = express.Router();
|
||||
const router = express.Router()
|
||||
|
||||
// 🔧 共享的消息处理函数
|
||||
async function handleMessagesRequest(req, res) {
|
||||
try {
|
||||
const startTime = Date.now();
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 严格的输入验证
|
||||
if (!req.body || typeof req.body !== 'object') {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid request',
|
||||
message: 'Request body must be a valid JSON object'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
if (!req.body.messages || !Array.isArray(req.body.messages)) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid request',
|
||||
message: 'Missing or invalid field: messages (must be an array)'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
if (req.body.messages.length === 0) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid request',
|
||||
message: 'Messages array cannot be empty'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 检查是否为流式请求
|
||||
const isStream = req.body.stream === true;
|
||||
|
||||
logger.api(`🚀 Processing ${isStream ? 'stream' : 'non-stream'} request for key: ${req.apiKey.name}`);
|
||||
const isStream = req.body.stream === true
|
||||
|
||||
logger.api(
|
||||
`🚀 Processing ${isStream ? 'stream' : 'non-stream'} request for key: ${req.apiKey.name}`
|
||||
)
|
||||
|
||||
if (isStream) {
|
||||
// 流式响应 - 只使用官方真实usage数据
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('Access-Control-Allow-Origin', '*');
|
||||
res.setHeader('X-Accel-Buffering', 'no'); // 禁用 Nginx 缓冲
|
||||
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('Access-Control-Allow-Origin', '*')
|
||||
res.setHeader('X-Accel-Buffering', 'no') // 禁用 Nginx 缓冲
|
||||
|
||||
// 禁用 Nagle 算法,确保数据立即发送
|
||||
if (res.socket && typeof res.socket.setNoDelay === 'function') {
|
||||
res.socket.setNoDelay(true);
|
||||
res.socket.setNoDelay(true)
|
||||
}
|
||||
|
||||
|
||||
// 流式响应不需要额外处理,中间件已经设置了监听器
|
||||
|
||||
let usageDataCaptured = false;
|
||||
|
||||
|
||||
let usageDataCaptured = false
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
|
||||
const requestedModel = req.body.model
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
|
||||
// 根据账号类型选择对应的转发服务并调用
|
||||
if (accountType === 'claude-official') {
|
||||
// 官方Claude账号使用原有的转发服务(会自己选择账号)
|
||||
await claudeRelayService.relayStreamRequestWithUsageCapture(req.body, req.apiKey, res, req.headers, (usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info('🎯 Usage callback triggered with complete data:', JSON.stringify(usageData, null, 2));
|
||||
|
||||
if (usageData && usageData.input_tokens !== undefined && usageData.output_tokens !== undefined) {
|
||||
const inputTokens = usageData.input_tokens || 0;
|
||||
const outputTokens = usageData.output_tokens || 0;
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0;
|
||||
const model = usageData.model || 'unknown';
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const accountId = usageData.accountId;
|
||||
apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model, accountId).catch(error => {
|
||||
logger.error('❌ Failed to record stream usage:', error);
|
||||
});
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens;
|
||||
redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens).catch(error => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error);
|
||||
});
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
await claudeRelayService.relayStreamRequestWithUsageCapture(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
res,
|
||||
req.headers,
|
||||
(usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info(
|
||||
'🎯 Usage callback triggered with complete data:',
|
||||
JSON.stringify(usageData, null, 2)
|
||||
)
|
||||
|
||||
if (
|
||||
usageData &&
|
||||
usageData.input_tokens !== undefined &&
|
||||
usageData.output_tokens !== undefined
|
||||
) {
|
||||
const inputTokens = usageData.input_tokens || 0
|
||||
const outputTokens = usageData.output_tokens || 0
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0
|
||||
const model = usageData.model || 'unknown'
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const { accountId: usageAccountId } = usageData
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
req.apiKey.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
usageAccountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record stream usage:', error)
|
||||
})
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
|
||||
redis
|
||||
.getClient()
|
||||
.incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error)
|
||||
})
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageDataCaptured = true
|
||||
logger.api(
|
||||
`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`
|
||||
)
|
||||
} else {
|
||||
logger.warn(
|
||||
'⚠️ Usage callback triggered but data is incomplete:',
|
||||
JSON.stringify(usageData)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
usageDataCaptured = true;
|
||||
logger.api(`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`);
|
||||
} else {
|
||||
logger.warn('⚠️ Usage callback triggered but data is incomplete:', JSON.stringify(usageData));
|
||||
}
|
||||
});
|
||||
)
|
||||
} else if (accountType === 'claude-console') {
|
||||
// Claude Console账号使用Console转发服务(需要传递accountId)
|
||||
await claudeConsoleRelayService.relayStreamRequestWithUsageCapture(req.body, req.apiKey, res, req.headers, (usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info('🎯 Usage callback triggered with complete data:', JSON.stringify(usageData, null, 2));
|
||||
|
||||
if (usageData && usageData.input_tokens !== undefined && usageData.output_tokens !== undefined) {
|
||||
const inputTokens = usageData.input_tokens || 0;
|
||||
const outputTokens = usageData.output_tokens || 0;
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0;
|
||||
const model = usageData.model || 'unknown';
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const usageAccountId = usageData.accountId;
|
||||
apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model, usageAccountId).catch(error => {
|
||||
logger.error('❌ Failed to record stream usage:', error);
|
||||
});
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens;
|
||||
redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens).catch(error => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error);
|
||||
});
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
await claudeConsoleRelayService.relayStreamRequestWithUsageCapture(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
res,
|
||||
req.headers,
|
||||
(usageData) => {
|
||||
// 回调函数:当检测到完整usage数据时记录真实token使用量
|
||||
logger.info(
|
||||
'🎯 Usage callback triggered with complete data:',
|
||||
JSON.stringify(usageData, null, 2)
|
||||
)
|
||||
|
||||
if (
|
||||
usageData &&
|
||||
usageData.input_tokens !== undefined &&
|
||||
usageData.output_tokens !== undefined
|
||||
) {
|
||||
const inputTokens = usageData.input_tokens || 0
|
||||
const outputTokens = usageData.output_tokens || 0
|
||||
const cacheCreateTokens = usageData.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = usageData.cache_read_input_tokens || 0
|
||||
const model = usageData.model || 'unknown'
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const usageAccountId = usageData.accountId
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
req.apiKey.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
usageAccountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record stream usage:', error)
|
||||
})
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
|
||||
redis
|
||||
.getClient()
|
||||
.incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error)
|
||||
})
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageDataCaptured = true
|
||||
logger.api(
|
||||
`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`
|
||||
)
|
||||
} else {
|
||||
logger.warn(
|
||||
'⚠️ Usage callback triggered but data is incomplete:',
|
||||
JSON.stringify(usageData)
|
||||
)
|
||||
}
|
||||
|
||||
usageDataCaptured = true;
|
||||
logger.api(`📊 Stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`);
|
||||
} else {
|
||||
logger.warn('⚠️ Usage callback triggered but data is incomplete:', JSON.stringify(usageData));
|
||||
}
|
||||
}, accountId);
|
||||
},
|
||||
accountId
|
||||
)
|
||||
} else if (accountType === 'bedrock') {
|
||||
// Bedrock账号使用Bedrock转发服务
|
||||
try {
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId);
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId)
|
||||
if (!bedrockAccountResult.success) {
|
||||
throw new Error('Failed to get Bedrock account details');
|
||||
throw new Error('Failed to get Bedrock account details')
|
||||
}
|
||||
|
||||
const result = await bedrockRelayService.handleStreamRequest(req.body, bedrockAccountResult.data, res);
|
||||
|
||||
const result = await bedrockRelayService.handleStreamRequest(
|
||||
req.body,
|
||||
bedrockAccountResult.data,
|
||||
res
|
||||
)
|
||||
|
||||
// 记录Bedrock使用统计
|
||||
if (result.usage) {
|
||||
const inputTokens = result.usage.input_tokens || 0;
|
||||
const outputTokens = result.usage.output_tokens || 0;
|
||||
|
||||
apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, 0, 0, result.model, accountId).catch(error => {
|
||||
logger.error('❌ Failed to record Bedrock stream usage:', error);
|
||||
});
|
||||
|
||||
const inputTokens = result.usage.input_tokens || 0
|
||||
const outputTokens = result.usage.output_tokens || 0
|
||||
|
||||
apiKeyService
|
||||
.recordUsage(req.apiKey.id, inputTokens, outputTokens, 0, 0, result.model, accountId)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record Bedrock stream usage:', error)
|
||||
})
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens;
|
||||
redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens).catch(error => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error);
|
||||
});
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
const totalTokens = inputTokens + outputTokens
|
||||
redis
|
||||
.getClient()
|
||||
.incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to update rate limit token count:', error)
|
||||
})
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageDataCaptured = true;
|
||||
logger.api(`📊 Bedrock stream usage recorded - Model: ${result.model}, Input: ${inputTokens}, Output: ${outputTokens}, Total: ${inputTokens + outputTokens} tokens`);
|
||||
|
||||
usageDataCaptured = true
|
||||
logger.api(
|
||||
`📊 Bedrock stream usage recorded - Model: ${result.model}, Input: ${inputTokens}, Output: ${outputTokens}, Total: ${inputTokens + outputTokens} tokens`
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Bedrock stream request failed:', error);
|
||||
logger.error('❌ Bedrock stream request failed:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({ error: 'Bedrock service error', message: error.message });
|
||||
return res.status(500).json({ error: 'Bedrock service error', message: error.message })
|
||||
}
|
||||
return;
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 流式请求完成后 - 如果没有捕获到usage数据,记录警告但不进行估算
|
||||
setTimeout(() => {
|
||||
if (!usageDataCaptured) {
|
||||
logger.warn('⚠️ No usage data captured from SSE stream - no statistics recorded (official data only)');
|
||||
logger.warn(
|
||||
'⚠️ No usage data captured from SSE stream - no statistics recorded (official data only)'
|
||||
)
|
||||
}
|
||||
}, 1000); // 1秒后检查
|
||||
}, 1000) // 1秒后检查
|
||||
} else {
|
||||
// 非流式响应 - 只使用官方真实usage数据
|
||||
logger.info('📄 Starting non-streaming request', {
|
||||
apiKeyId: req.apiKey.id,
|
||||
apiKeyName: req.apiKey.name
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
|
||||
const requestedModel = req.body.model
|
||||
const { accountId, accountType } = await unifiedClaudeScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
|
||||
// 根据账号类型选择对应的转发服务
|
||||
let response;
|
||||
logger.debug(`[DEBUG] Request query params: ${JSON.stringify(req.query)}`);
|
||||
logger.debug(`[DEBUG] Request URL: ${req.url}`);
|
||||
logger.debug(`[DEBUG] Request path: ${req.path}`);
|
||||
|
||||
let response
|
||||
logger.debug(`[DEBUG] Request query params: ${JSON.stringify(req.query)}`)
|
||||
logger.debug(`[DEBUG] Request URL: ${req.url}`)
|
||||
logger.debug(`[DEBUG] Request path: ${req.path}`)
|
||||
|
||||
if (accountType === 'claude-official') {
|
||||
// 官方Claude账号使用原有的转发服务
|
||||
response = await claudeRelayService.relayRequest(req.body, req.apiKey, req, res, req.headers);
|
||||
response = await claudeRelayService.relayRequest(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
req,
|
||||
res,
|
||||
req.headers
|
||||
)
|
||||
} else if (accountType === 'claude-console') {
|
||||
// Claude Console账号使用Console转发服务
|
||||
logger.debug(`[DEBUG] Calling claudeConsoleRelayService.relayRequest with accountId: ${accountId}`);
|
||||
response = await claudeConsoleRelayService.relayRequest(req.body, req.apiKey, req, res, req.headers, accountId);
|
||||
logger.debug(
|
||||
`[DEBUG] Calling claudeConsoleRelayService.relayRequest with accountId: ${accountId}`
|
||||
)
|
||||
response = await claudeConsoleRelayService.relayRequest(
|
||||
req.body,
|
||||
req.apiKey,
|
||||
req,
|
||||
res,
|
||||
req.headers,
|
||||
accountId
|
||||
)
|
||||
} else if (accountType === 'bedrock') {
|
||||
// Bedrock账号使用Bedrock转发服务
|
||||
try {
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId);
|
||||
const bedrockAccountResult = await bedrockAccountService.getAccount(accountId)
|
||||
if (!bedrockAccountResult.success) {
|
||||
throw new Error('Failed to get Bedrock account details');
|
||||
throw new Error('Failed to get Bedrock account details')
|
||||
}
|
||||
|
||||
const result = await bedrockRelayService.handleNonStreamRequest(req.body, bedrockAccountResult.data, req.headers);
|
||||
|
||||
const result = await bedrockRelayService.handleNonStreamRequest(
|
||||
req.body,
|
||||
bedrockAccountResult.data,
|
||||
req.headers
|
||||
)
|
||||
|
||||
// 构建标准响应格式
|
||||
response = {
|
||||
statusCode: result.success ? 200 : 500,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(result.success ? result.data : { error: result.error }),
|
||||
accountId: accountId
|
||||
};
|
||||
|
||||
accountId
|
||||
}
|
||||
|
||||
// 如果成功,添加使用统计到响应数据中
|
||||
if (result.success && result.usage) {
|
||||
const responseData = JSON.parse(response.body);
|
||||
responseData.usage = result.usage;
|
||||
response.body = JSON.stringify(responseData);
|
||||
const responseData = JSON.parse(response.body)
|
||||
responseData.usage = result.usage
|
||||
response.body = JSON.stringify(responseData)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Bedrock non-stream request failed:', error);
|
||||
logger.error('❌ Bedrock non-stream request failed:', error)
|
||||
response = {
|
||||
statusCode: 500,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ error: 'Bedrock service error', message: error.message }),
|
||||
accountId: accountId
|
||||
};
|
||||
accountId
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
logger.info('📡 Claude API response received', {
|
||||
statusCode: response.statusCode,
|
||||
headers: JSON.stringify(response.headers),
|
||||
bodyLength: response.body ? response.body.length : 0
|
||||
});
|
||||
|
||||
res.status(response.statusCode);
|
||||
|
||||
})
|
||||
|
||||
res.status(response.statusCode)
|
||||
|
||||
// 设置响应头,避免 Content-Length 和 Transfer-Encoding 冲突
|
||||
const skipHeaders = ['content-encoding', 'transfer-encoding', 'content-length'];
|
||||
Object.keys(response.headers).forEach(key => {
|
||||
const skipHeaders = ['content-encoding', 'transfer-encoding', 'content-length']
|
||||
Object.keys(response.headers).forEach((key) => {
|
||||
if (!skipHeaders.includes(key.toLowerCase())) {
|
||||
res.setHeader(key, response.headers[key]);
|
||||
res.setHeader(key, response.headers[key])
|
||||
}
|
||||
});
|
||||
|
||||
let usageRecorded = false;
|
||||
|
||||
})
|
||||
|
||||
let usageRecorded = false
|
||||
|
||||
// 尝试解析JSON响应并提取usage信息
|
||||
try {
|
||||
const jsonData = JSON.parse(response.body);
|
||||
|
||||
logger.info('📊 Parsed Claude API response:', JSON.stringify(jsonData, null, 2));
|
||||
|
||||
const jsonData = JSON.parse(response.body)
|
||||
|
||||
logger.info('📊 Parsed Claude API response:', JSON.stringify(jsonData, null, 2))
|
||||
|
||||
// 从Claude API响应中提取usage信息(完整的token分类体系)
|
||||
if (jsonData.usage && jsonData.usage.input_tokens !== undefined && jsonData.usage.output_tokens !== undefined) {
|
||||
const inputTokens = jsonData.usage.input_tokens || 0;
|
||||
const outputTokens = jsonData.usage.output_tokens || 0;
|
||||
const cacheCreateTokens = jsonData.usage.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = jsonData.usage.cache_read_input_tokens || 0;
|
||||
const model = jsonData.model || req.body.model || 'unknown';
|
||||
|
||||
if (
|
||||
jsonData.usage &&
|
||||
jsonData.usage.input_tokens !== undefined &&
|
||||
jsonData.usage.output_tokens !== undefined
|
||||
) {
|
||||
const inputTokens = jsonData.usage.input_tokens || 0
|
||||
const outputTokens = jsonData.usage.output_tokens || 0
|
||||
const cacheCreateTokens = jsonData.usage.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = jsonData.usage.cache_read_input_tokens || 0
|
||||
const model = jsonData.model || req.body.model || 'unknown'
|
||||
|
||||
// 记录真实的token使用量(包含模型信息和所有4种token以及账户ID)
|
||||
const accountId = response.accountId;
|
||||
await apiKeyService.recordUsage(req.apiKey.id, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model, accountId);
|
||||
|
||||
const { accountId: responseAccountId } = response
|
||||
await apiKeyService.recordUsage(
|
||||
req.apiKey.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
responseAccountId
|
||||
)
|
||||
|
||||
// 更新时间窗口内的token计数
|
||||
if (req.rateLimitInfo) {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens;
|
||||
await redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens);
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`);
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
|
||||
await redis.getClient().incrby(req.rateLimitInfo.tokenCountKey, totalTokens)
|
||||
logger.api(`📊 Updated rate limit token count: +${totalTokens} tokens`)
|
||||
}
|
||||
|
||||
usageRecorded = true;
|
||||
logger.api(`📊 Non-stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`);
|
||||
|
||||
usageRecorded = true
|
||||
logger.api(
|
||||
`📊 Non-stream usage recorded (real) - Model: ${model}, Input: ${inputTokens}, Output: ${outputTokens}, Cache Create: ${cacheCreateTokens}, Cache Read: ${cacheReadTokens}, Total: ${inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens} tokens`
|
||||
)
|
||||
} else {
|
||||
logger.warn('⚠️ No usage data found in Claude API JSON response');
|
||||
logger.warn('⚠️ No usage data found in Claude API JSON response')
|
||||
}
|
||||
|
||||
res.json(jsonData);
|
||||
|
||||
res.json(jsonData)
|
||||
} catch (parseError) {
|
||||
logger.warn('⚠️ Failed to parse Claude API response as JSON:', parseError.message);
|
||||
logger.info('📄 Raw response body:', response.body);
|
||||
res.send(response.body);
|
||||
logger.warn('⚠️ Failed to parse Claude API response as JSON:', parseError.message)
|
||||
logger.info('📄 Raw response body:', response.body)
|
||||
res.send(response.body)
|
||||
}
|
||||
|
||||
|
||||
// 如果没有记录usage,只记录警告,不进行估算
|
||||
if (!usageRecorded) {
|
||||
logger.warn('⚠️ No usage data recorded for non-stream request - no statistics recorded (official data only)');
|
||||
logger.warn(
|
||||
'⚠️ No usage data recorded for non-stream request - no statistics recorded (official data only)'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.api(`✅ Request completed in ${duration}ms for key: ${req.apiKey.name}`);
|
||||
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.api(`✅ Request completed in ${duration}ms for key: ${req.apiKey.name}`)
|
||||
return undefined
|
||||
} catch (error) {
|
||||
logger.error('❌ Claude relay error:', error.message, {
|
||||
code: error.code,
|
||||
stack: error.stack
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 确保在任何情况下都能返回有效的JSON响应
|
||||
if (!res.headersSent) {
|
||||
// 根据错误类型设置适当的状态码
|
||||
let statusCode = 500;
|
||||
let errorType = 'Relay service error';
|
||||
|
||||
let statusCode = 500
|
||||
let errorType = 'Relay service error'
|
||||
|
||||
if (error.message.includes('Connection reset') || error.message.includes('socket hang up')) {
|
||||
statusCode = 502;
|
||||
errorType = 'Upstream connection error';
|
||||
statusCode = 502
|
||||
errorType = 'Upstream connection error'
|
||||
} else if (error.message.includes('Connection refused')) {
|
||||
statusCode = 502;
|
||||
errorType = 'Upstream service unavailable';
|
||||
statusCode = 502
|
||||
errorType = 'Upstream service unavailable'
|
||||
} else if (error.message.includes('timeout')) {
|
||||
statusCode = 504;
|
||||
errorType = 'Upstream timeout';
|
||||
statusCode = 504
|
||||
errorType = 'Upstream timeout'
|
||||
} else if (error.message.includes('resolve') || error.message.includes('ENOTFOUND')) {
|
||||
statusCode = 502;
|
||||
errorType = 'Upstream hostname resolution failed';
|
||||
statusCode = 502
|
||||
errorType = 'Upstream hostname resolution failed'
|
||||
}
|
||||
|
||||
res.status(statusCode).json({
|
||||
|
||||
return res.status(statusCode).json({
|
||||
error: errorType,
|
||||
message: error.message || 'An unexpected error occurred',
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} else {
|
||||
// 如果响应头已经发送,尝试结束响应
|
||||
if (!res.destroyed && !res.finished) {
|
||||
res.end();
|
||||
res.end()
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 🚀 Claude API messages 端点 - /api/v1/messages
|
||||
router.post('/v1/messages', authenticateApiKey, handleMessagesRequest);
|
||||
router.post('/v1/messages', authenticateApiKey, handleMessagesRequest)
|
||||
|
||||
// 🚀 Claude API messages 端点 - /claude/v1/messages (别名)
|
||||
router.post('/claude/v1/messages', authenticateApiKey, handleMessagesRequest);
|
||||
router.post('/claude/v1/messages', authenticateApiKey, handleMessagesRequest)
|
||||
|
||||
// 📋 模型列表端点 - Claude Code 客户端需要
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
@@ -368,66 +490,65 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
owned_by: 'anthropic'
|
||||
},
|
||||
{
|
||||
id: 'claude-3-5-haiku-20241022',
|
||||
id: 'claude-3-5-haiku-20241022',
|
||||
object: 'model',
|
||||
created: 1669599635,
|
||||
owned_by: 'anthropic'
|
||||
},
|
||||
{
|
||||
id: 'claude-3-opus-20240229',
|
||||
object: 'model',
|
||||
object: 'model',
|
||||
created: 1669599635,
|
||||
owned_by: 'anthropic'
|
||||
},
|
||||
{
|
||||
id: 'claude-sonnet-4-20250514',
|
||||
object: 'model',
|
||||
created: 1669599635,
|
||||
created: 1669599635,
|
||||
owned_by: 'anthropic'
|
||||
}
|
||||
];
|
||||
|
||||
]
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Models list error:', error);
|
||||
logger.error('❌ Models list error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get models list',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🏥 健康检查端点
|
||||
router.get('/health', async (req, res) => {
|
||||
try {
|
||||
const healthStatus = await claudeRelayService.healthCheck();
|
||||
|
||||
const healthStatus = await claudeRelayService.healthCheck()
|
||||
|
||||
res.status(healthStatus.healthy ? 200 : 503).json({
|
||||
status: healthStatus.healthy ? 'healthy' : 'unhealthy',
|
||||
service: 'claude-relay-service',
|
||||
version: '1.0.0',
|
||||
...healthStatus
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Health check error:', error);
|
||||
logger.error('❌ Health check error:', error)
|
||||
res.status(503).json({
|
||||
status: 'unhealthy',
|
||||
service: 'claude-relay-service',
|
||||
error: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📊 API Key状态检查端点 - /api/v1/key-info
|
||||
router.get('/v1/key-info', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id);
|
||||
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id)
|
||||
|
||||
res.json({
|
||||
keyInfo: {
|
||||
id: req.apiKey.id,
|
||||
@@ -436,21 +557,21 @@ router.get('/v1/key-info', authenticateApiKey, async (req, res) => {
|
||||
usage
|
||||
},
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Key info error:', error);
|
||||
logger.error('❌ Key info error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get key info',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📈 使用统计端点 - /api/v1/usage
|
||||
router.get('/v1/usage', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id);
|
||||
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id)
|
||||
|
||||
res.json({
|
||||
usage,
|
||||
limits: {
|
||||
@@ -458,56 +579,56 @@ router.get('/v1/usage', authenticateApiKey, async (req, res) => {
|
||||
requests: 0 // 请求限制已移除
|
||||
},
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Usage stats error:', error);
|
||||
logger.error('❌ Usage stats error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get usage stats',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 👤 用户信息端点 - Claude Code 客户端需要
|
||||
router.get('/v1/me', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
// 返回基础用户信息
|
||||
res.json({
|
||||
id: 'user_' + req.apiKey.id,
|
||||
type: 'user',
|
||||
id: `user_${req.apiKey.id}`,
|
||||
type: 'user',
|
||||
display_name: req.apiKey.name || 'API User',
|
||||
created_at: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ User info error:', error);
|
||||
logger.error('❌ User info error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get user info',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 💰 余额/限制端点 - Claude Code 客户端需要
|
||||
router.get('/v1/organizations/:org_id/usage', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id);
|
||||
|
||||
const usage = await apiKeyService.getUsageStats(req.apiKey.id)
|
||||
|
||||
res.json({
|
||||
object: 'usage',
|
||||
data: [
|
||||
{
|
||||
type: 'credit_balance',
|
||||
type: 'credit_balance',
|
||||
credit_balance: req.apiKey.tokenLimit - (usage.totalTokens || 0)
|
||||
}
|
||||
]
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Organization usage error:', error);
|
||||
logger.error('❌ Organization usage error:', error)
|
||||
res.status(500).json({
|
||||
error: 'Failed to get usage info',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
const express = require('express');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const CostCalculator = require('../utils/costCalculator');
|
||||
const express = require('express')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
const CostCalculator = require('../utils/costCalculator')
|
||||
|
||||
const router = express.Router();
|
||||
const router = express.Router()
|
||||
|
||||
// 🏠 重定向页面请求到新版 admin-spa
|
||||
router.get('/', (req, res) => {
|
||||
res.redirect(301, '/admin-next/api-stats');
|
||||
});
|
||||
res.redirect(301, '/admin-next/api-stats')
|
||||
})
|
||||
|
||||
// 🔑 获取 API Key 对应的 ID
|
||||
router.post('/api/get-key-id', async (req, res) => {
|
||||
try {
|
||||
const { apiKey } = req.body;
|
||||
|
||||
const { apiKey } = req.body
|
||||
|
||||
if (!apiKey) {
|
||||
return res.status(400).json({
|
||||
error: 'API Key is required',
|
||||
message: 'Please provide your API Key'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 基本API Key格式验证
|
||||
@@ -28,108 +28,110 @@ router.post('/api/get-key-id', async (req, res) => {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API key format',
|
||||
message: 'API key format is invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证API Key
|
||||
const validation = await apiKeyService.validateApiKey(apiKey);
|
||||
|
||||
const validation = await apiKeyService.validateApiKey(apiKey)
|
||||
|
||||
if (!validation.valid) {
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown';
|
||||
logger.security(`🔒 Invalid API key in get-key-id: ${validation.error} from ${clientIP}`);
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown'
|
||||
logger.security(`🔒 Invalid API key in get-key-id: ${validation.error} from ${clientIP}`)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid API key',
|
||||
message: validation.error
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const keyData = validation.keyData;
|
||||
|
||||
res.json({
|
||||
const { keyData } = validation
|
||||
|
||||
return res.json({
|
||||
success: true,
|
||||
data: {
|
||||
id: keyData.id
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get API key ID:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Failed to get API key ID:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: 'Failed to retrieve API key ID'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📊 用户API Key统计查询接口 - 安全的自查询接口
|
||||
router.post('/api/user-stats', async (req, res) => {
|
||||
try {
|
||||
const { apiKey, apiId } = req.body;
|
||||
|
||||
let keyData;
|
||||
let keyId;
|
||||
|
||||
const { apiKey, apiId } = req.body
|
||||
|
||||
let keyData
|
||||
let keyId
|
||||
|
||||
if (apiId) {
|
||||
// 通过 apiId 查询
|
||||
if (typeof apiId !== 'string' || !apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)) {
|
||||
if (
|
||||
typeof apiId !== 'string' ||
|
||||
!apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)
|
||||
) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API ID format',
|
||||
message: 'API ID must be a valid UUID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 直接通过 ID 获取 API Key 数据
|
||||
keyData = await redis.getApiKey(apiId);
|
||||
|
||||
keyData = await redis.getApiKey(apiId)
|
||||
|
||||
if (!keyData || Object.keys(keyData).length === 0) {
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`)
|
||||
return res.status(404).json({
|
||||
error: 'API key not found',
|
||||
message: 'The specified API key does not exist'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查是否激活
|
||||
if (keyData.isActive !== 'true') {
|
||||
return res.status(403).json({
|
||||
error: 'API key is disabled',
|
||||
message: 'This API key has been disabled'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查是否过期
|
||||
if (keyData.expiresAt && new Date() > new Date(keyData.expiresAt)) {
|
||||
return res.status(403).json({
|
||||
error: 'API key has expired',
|
||||
message: 'This API key has expired'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyId = apiId;
|
||||
|
||||
|
||||
keyId = apiId
|
||||
|
||||
// 获取使用统计
|
||||
const usage = await redis.getUsageStats(keyId);
|
||||
|
||||
const usage = await redis.getUsageStats(keyId)
|
||||
|
||||
// 获取当日费用统计
|
||||
const dailyCost = await redis.getDailyCost(keyId);
|
||||
|
||||
const dailyCost = await redis.getDailyCost(keyId)
|
||||
|
||||
// 处理数据格式,与 validateApiKey 返回的格式保持一致
|
||||
// 解析限制模型数据
|
||||
let restrictedModels = [];
|
||||
let restrictedModels = []
|
||||
try {
|
||||
restrictedModels = keyData.restrictedModels ? JSON.parse(keyData.restrictedModels) : [];
|
||||
restrictedModels = keyData.restrictedModels ? JSON.parse(keyData.restrictedModels) : []
|
||||
} catch (e) {
|
||||
restrictedModels = [];
|
||||
restrictedModels = []
|
||||
}
|
||||
|
||||
|
||||
// 解析允许的客户端数据
|
||||
let allowedClients = [];
|
||||
let allowedClients = []
|
||||
try {
|
||||
allowedClients = keyData.allowedClients ? JSON.parse(keyData.allowedClients) : [];
|
||||
allowedClients = keyData.allowedClients ? JSON.parse(keyData.allowedClients) : []
|
||||
} catch (e) {
|
||||
allowedClients = [];
|
||||
allowedClients = []
|
||||
}
|
||||
|
||||
|
||||
// 格式化 keyData
|
||||
keyData = {
|
||||
...keyData,
|
||||
@@ -140,70 +142,75 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
dailyCostLimit: parseFloat(keyData.dailyCostLimit) || 0,
|
||||
dailyCost: dailyCost || 0,
|
||||
enableModelRestriction: keyData.enableModelRestriction === 'true',
|
||||
restrictedModels: restrictedModels,
|
||||
restrictedModels,
|
||||
enableClientRestriction: keyData.enableClientRestriction === 'true',
|
||||
allowedClients: allowedClients,
|
||||
allowedClients,
|
||||
permissions: keyData.permissions || 'all',
|
||||
usage: usage // 使用完整的 usage 数据,而不是只有 total
|
||||
};
|
||||
|
||||
usage // 使用完整的 usage 数据,而不是只有 total
|
||||
}
|
||||
} else if (apiKey) {
|
||||
// 通过 apiKey 查询(保持向后兼容)
|
||||
if (typeof apiKey !== 'string' || apiKey.length < 10 || apiKey.length > 512) {
|
||||
logger.security(`🔒 Invalid API key format in user stats query from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 Invalid API key format in user stats query from ${req.ip || 'unknown'}`)
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API key format',
|
||||
message: 'API key format is invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证API Key(重用现有的验证逻辑)
|
||||
const validation = await apiKeyService.validateApiKey(apiKey);
|
||||
|
||||
const validation = await apiKeyService.validateApiKey(apiKey)
|
||||
|
||||
if (!validation.valid) {
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown';
|
||||
logger.security(`🔒 Invalid API key in user stats query: ${validation.error} from ${clientIP}`);
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown'
|
||||
logger.security(
|
||||
`🔒 Invalid API key in user stats query: ${validation.error} from ${clientIP}`
|
||||
)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid API key',
|
||||
message: validation.error
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyData = validation.keyData;
|
||||
keyId = keyData.id;
|
||||
|
||||
const { keyData: validatedKeyData } = validation
|
||||
keyData = validatedKeyData
|
||||
keyId = keyData.id
|
||||
} else {
|
||||
logger.security(`🔒 Missing API key or ID in user stats query from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 Missing API key or ID in user stats query from ${req.ip || 'unknown'}`)
|
||||
return res.status(400).json({
|
||||
error: 'API Key or ID is required',
|
||||
message: 'Please provide your API Key or API ID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 记录合法查询
|
||||
logger.api(`📊 User stats query from key: ${keyData.name} (${keyId}) from ${req.ip || 'unknown'}`);
|
||||
logger.api(
|
||||
`📊 User stats query from key: ${keyData.name} (${keyId}) from ${req.ip || 'unknown'}`
|
||||
)
|
||||
|
||||
// 获取验证结果中的完整keyData(包含isActive状态和cost信息)
|
||||
const fullKeyData = keyData;
|
||||
|
||||
const fullKeyData = keyData
|
||||
|
||||
// 计算总费用 - 使用与模型统计相同的逻辑(按模型分别计算)
|
||||
let totalCost = 0;
|
||||
let formattedCost = '$0.000000';
|
||||
|
||||
let totalCost = 0
|
||||
let formattedCost = '$0.000000'
|
||||
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 获取所有月度模型统计(与model-stats接口相同的逻辑)
|
||||
const allModelKeys = await client.keys(`usage:${keyId}:model:monthly:*:*`);
|
||||
const modelUsageMap = new Map();
|
||||
|
||||
const allModelKeys = await client.keys(`usage:${keyId}:model:monthly:*:*`)
|
||||
const modelUsageMap = new Map()
|
||||
|
||||
for (const key of allModelKeys) {
|
||||
const modelMatch = key.match(/usage:.+:model:monthly:(.+):(\d{4}-\d{2})$/);
|
||||
if (!modelMatch) continue;
|
||||
|
||||
const model = modelMatch[1];
|
||||
const data = await client.hgetall(key);
|
||||
|
||||
const modelMatch = key.match(/usage:.+:model:monthly:(.+):(\d{4}-\d{2})$/)
|
||||
if (!modelMatch) {
|
||||
continue
|
||||
}
|
||||
|
||||
const model = modelMatch[1]
|
||||
const data = await client.hgetall(key)
|
||||
|
||||
if (data && Object.keys(data).length > 0) {
|
||||
if (!modelUsageMap.has(model)) {
|
||||
modelUsageMap.set(model, {
|
||||
@@ -211,17 +218,17 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
outputTokens: 0,
|
||||
cacheCreateTokens: 0,
|
||||
cacheReadTokens: 0
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const modelUsage = modelUsageMap.get(model);
|
||||
modelUsage.inputTokens += parseInt(data.inputTokens) || 0;
|
||||
modelUsage.outputTokens += parseInt(data.outputTokens) || 0;
|
||||
modelUsage.cacheCreateTokens += parseInt(data.cacheCreateTokens) || 0;
|
||||
modelUsage.cacheReadTokens += parseInt(data.cacheReadTokens) || 0;
|
||||
|
||||
const modelUsage = modelUsageMap.get(model)
|
||||
modelUsage.inputTokens += parseInt(data.inputTokens) || 0
|
||||
modelUsage.outputTokens += parseInt(data.outputTokens) || 0
|
||||
modelUsage.cacheCreateTokens += parseInt(data.cacheCreateTokens) || 0
|
||||
modelUsage.cacheReadTokens += parseInt(data.cacheReadTokens) || 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 按模型计算费用并汇总
|
||||
for (const [model, usage] of modelUsageMap) {
|
||||
const usageData = {
|
||||
@@ -229,66 +236,65 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
output_tokens: usage.outputTokens,
|
||||
cache_creation_input_tokens: usage.cacheCreateTokens,
|
||||
cache_read_input_tokens: usage.cacheReadTokens
|
||||
};
|
||||
|
||||
const costResult = CostCalculator.calculateCost(usageData, model);
|
||||
totalCost += costResult.costs.total;
|
||||
}
|
||||
|
||||
const costResult = CostCalculator.calculateCost(usageData, model)
|
||||
totalCost += costResult.costs.total
|
||||
}
|
||||
|
||||
|
||||
// 如果没有模型级别的详细数据,回退到总体数据计算
|
||||
if (modelUsageMap.size === 0 && fullKeyData.usage?.total?.allTokens > 0) {
|
||||
const usage = fullKeyData.usage.total;
|
||||
const usage = fullKeyData.usage.total
|
||||
const costUsage = {
|
||||
input_tokens: usage.inputTokens || 0,
|
||||
output_tokens: usage.outputTokens || 0,
|
||||
cache_creation_input_tokens: usage.cacheCreateTokens || 0,
|
||||
cache_read_input_tokens: usage.cacheReadTokens || 0
|
||||
};
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022');
|
||||
totalCost = costResult.costs.total;
|
||||
}
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022')
|
||||
totalCost = costResult.costs.total
|
||||
}
|
||||
|
||||
formattedCost = CostCalculator.formatCost(totalCost);
|
||||
|
||||
|
||||
formattedCost = CostCalculator.formatCost(totalCost)
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to calculate detailed cost for key ${keyId}:`, error);
|
||||
logger.warn(`Failed to calculate detailed cost for key ${keyId}:`, error)
|
||||
// 回退到简单计算
|
||||
if (fullKeyData.usage?.total?.allTokens > 0) {
|
||||
const usage = fullKeyData.usage.total;
|
||||
const usage = fullKeyData.usage.total
|
||||
const costUsage = {
|
||||
input_tokens: usage.inputTokens || 0,
|
||||
output_tokens: usage.outputTokens || 0,
|
||||
cache_creation_input_tokens: usage.cacheCreateTokens || 0,
|
||||
cache_read_input_tokens: usage.cacheReadTokens || 0
|
||||
};
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022');
|
||||
totalCost = costResult.costs.total;
|
||||
formattedCost = costResult.formatted.total;
|
||||
}
|
||||
|
||||
const costResult = CostCalculator.calculateCost(costUsage, 'claude-3-5-sonnet-20241022')
|
||||
totalCost = costResult.costs.total
|
||||
formattedCost = costResult.formatted.total
|
||||
}
|
||||
}
|
||||
|
||||
// 获取当前使用量
|
||||
let currentWindowRequests = 0;
|
||||
let currentWindowTokens = 0;
|
||||
let currentDailyCost = 0;
|
||||
|
||||
let currentWindowRequests = 0
|
||||
let currentWindowTokens = 0
|
||||
let currentDailyCost = 0
|
||||
|
||||
try {
|
||||
// 获取当前时间窗口的请求次数和Token使用量
|
||||
if (fullKeyData.rateLimitWindow > 0) {
|
||||
const client = redis.getClientSafe();
|
||||
const requestCountKey = `rate_limit:requests:${keyId}`;
|
||||
const tokenCountKey = `rate_limit:tokens:${keyId}`;
|
||||
|
||||
currentWindowRequests = parseInt(await client.get(requestCountKey) || '0');
|
||||
currentWindowTokens = parseInt(await client.get(tokenCountKey) || '0');
|
||||
const client = redis.getClientSafe()
|
||||
const requestCountKey = `rate_limit:requests:${keyId}`
|
||||
const tokenCountKey = `rate_limit:tokens:${keyId}`
|
||||
|
||||
currentWindowRequests = parseInt((await client.get(requestCountKey)) || '0')
|
||||
currentWindowTokens = parseInt((await client.get(tokenCountKey)) || '0')
|
||||
}
|
||||
|
||||
|
||||
// 获取当日费用
|
||||
currentDailyCost = await redis.getDailyCost(keyId) || 0;
|
||||
currentDailyCost = (await redis.getDailyCost(keyId)) || 0
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to get current usage for key ${keyId}:`, error);
|
||||
logger.warn(`Failed to get current usage for key ${keyId}:`, error)
|
||||
}
|
||||
|
||||
// 构建响应数据(只返回该API Key自己的信息,确保不泄露其他信息)
|
||||
@@ -300,7 +306,7 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
createdAt: keyData.createdAt,
|
||||
expiresAt: keyData.expiresAt,
|
||||
permissions: fullKeyData.permissions,
|
||||
|
||||
|
||||
// 使用统计(使用验证结果中的完整数据)
|
||||
usage: {
|
||||
total: {
|
||||
@@ -314,10 +320,10 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
cacheReadTokens: 0
|
||||
}),
|
||||
cost: totalCost,
|
||||
formattedCost: formattedCost
|
||||
formattedCost
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
// 限制信息(显示配置和当前使用量)
|
||||
limits: {
|
||||
tokenLimit: fullKeyData.tokenLimit || 0,
|
||||
@@ -326,17 +332,23 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
rateLimitRequests: fullKeyData.rateLimitRequests || 0,
|
||||
dailyCostLimit: fullKeyData.dailyCostLimit || 0,
|
||||
// 当前使用量
|
||||
currentWindowRequests: currentWindowRequests,
|
||||
currentWindowTokens: currentWindowTokens,
|
||||
currentDailyCost: currentDailyCost
|
||||
currentWindowRequests,
|
||||
currentWindowTokens,
|
||||
currentDailyCost
|
||||
},
|
||||
|
||||
|
||||
// 绑定的账户信息(只显示ID,不显示敏感信息)
|
||||
accounts: {
|
||||
claudeAccountId: fullKeyData.claudeAccountId && fullKeyData.claudeAccountId !== '' ? fullKeyData.claudeAccountId : null,
|
||||
geminiAccountId: fullKeyData.geminiAccountId && fullKeyData.geminiAccountId !== '' ? fullKeyData.geminiAccountId : null
|
||||
claudeAccountId:
|
||||
fullKeyData.claudeAccountId && fullKeyData.claudeAccountId !== ''
|
||||
? fullKeyData.claudeAccountId
|
||||
: null,
|
||||
geminiAccountId:
|
||||
fullKeyData.geminiAccountId && fullKeyData.geminiAccountId !== ''
|
||||
? fullKeyData.geminiAccountId
|
||||
: null
|
||||
},
|
||||
|
||||
|
||||
// 模型和客户端限制信息
|
||||
restrictions: {
|
||||
enableModelRestriction: fullKeyData.enableModelRestriction || false,
|
||||
@@ -344,126 +356,137 @@ router.post('/api/user-stats', async (req, res) => {
|
||||
enableClientRestriction: fullKeyData.enableClientRestriction || false,
|
||||
allowedClients: fullKeyData.allowedClients || []
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
data: responseData
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to process user stats query:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Failed to process user stats query:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: 'Failed to retrieve API key statistics'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 📊 用户模型统计查询接口 - 安全的自查询接口
|
||||
router.post('/api/user-model-stats', async (req, res) => {
|
||||
try {
|
||||
const { apiKey, apiId, period = 'monthly' } = req.body;
|
||||
|
||||
let keyData;
|
||||
let keyId;
|
||||
|
||||
const { apiKey, apiId, period = 'monthly' } = req.body
|
||||
|
||||
let keyData
|
||||
let keyId
|
||||
|
||||
if (apiId) {
|
||||
// 通过 apiId 查询
|
||||
if (typeof apiId !== 'string' || !apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)) {
|
||||
if (
|
||||
typeof apiId !== 'string' ||
|
||||
!apiId.match(/^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$/i)
|
||||
) {
|
||||
return res.status(400).json({
|
||||
error: 'Invalid API ID format',
|
||||
message: 'API ID must be a valid UUID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 直接通过 ID 获取 API Key 数据
|
||||
keyData = await redis.getApiKey(apiId);
|
||||
|
||||
keyData = await redis.getApiKey(apiId)
|
||||
|
||||
if (!keyData || Object.keys(keyData).length === 0) {
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`);
|
||||
logger.security(`🔒 API key not found for ID: ${apiId} from ${req.ip || 'unknown'}`)
|
||||
return res.status(404).json({
|
||||
error: 'API key not found',
|
||||
message: 'The specified API key does not exist'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查是否激活
|
||||
if (keyData.isActive !== 'true') {
|
||||
return res.status(403).json({
|
||||
error: 'API key is disabled',
|
||||
message: 'This API key has been disabled'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyId = apiId;
|
||||
|
||||
|
||||
keyId = apiId
|
||||
|
||||
// 获取使用统计
|
||||
const usage = await redis.getUsageStats(keyId);
|
||||
keyData.usage = { total: usage.total };
|
||||
|
||||
const usage = await redis.getUsageStats(keyId)
|
||||
keyData.usage = { total: usage.total }
|
||||
} else if (apiKey) {
|
||||
// 通过 apiKey 查询(保持向后兼容)
|
||||
// 验证API Key
|
||||
const validation = await apiKeyService.validateApiKey(apiKey);
|
||||
|
||||
const validation = await apiKeyService.validateApiKey(apiKey)
|
||||
|
||||
if (!validation.valid) {
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown';
|
||||
logger.security(`🔒 Invalid API key in user model stats query: ${validation.error} from ${clientIP}`);
|
||||
const clientIP = req.ip || req.connection?.remoteAddress || 'unknown'
|
||||
logger.security(
|
||||
`🔒 Invalid API key in user model stats query: ${validation.error} from ${clientIP}`
|
||||
)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid API key',
|
||||
message: validation.error
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
keyData = validation.keyData;
|
||||
keyId = keyData.id;
|
||||
|
||||
const { keyData: validatedKeyData } = validation
|
||||
keyData = validatedKeyData
|
||||
keyId = keyData.id
|
||||
} else {
|
||||
logger.security(`🔒 Missing API key or ID in user model stats query from ${req.ip || 'unknown'}`);
|
||||
logger.security(
|
||||
`🔒 Missing API key or ID in user model stats query from ${req.ip || 'unknown'}`
|
||||
)
|
||||
return res.status(400).json({
|
||||
error: 'API Key or ID is required',
|
||||
message: 'Please provide your API Key or API ID'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
logger.api(`📊 User model stats query from key: ${keyData.name} (${keyId}) for period: ${period}`);
|
||||
|
||||
logger.api(
|
||||
`📊 User model stats query from key: ${keyData.name} (${keyId}) for period: ${period}`
|
||||
)
|
||||
|
||||
// 重用管理后台的模型统计逻辑,但只返回该API Key的数据
|
||||
const client = redis.getClientSafe();
|
||||
const client = redis.getClientSafe()
|
||||
// 使用与管理页面相同的时区处理逻辑
|
||||
const tzDate = redis.getDateInTimezone();
|
||||
const today = redis.getDateStringInTimezone();
|
||||
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`;
|
||||
|
||||
const pattern = period === 'daily' ?
|
||||
`usage:${keyId}:model:daily:*:${today}` :
|
||||
`usage:${keyId}:model:monthly:*:${currentMonth}`;
|
||||
|
||||
const keys = await client.keys(pattern);
|
||||
const modelStats = [];
|
||||
|
||||
const tzDate = redis.getDateInTimezone()
|
||||
const today = redis.getDateStringInTimezone()
|
||||
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`
|
||||
|
||||
const pattern =
|
||||
period === 'daily'
|
||||
? `usage:${keyId}:model:daily:*:${today}`
|
||||
: `usage:${keyId}:model:monthly:*:${currentMonth}`
|
||||
|
||||
const keys = await client.keys(pattern)
|
||||
const modelStats = []
|
||||
|
||||
for (const key of keys) {
|
||||
const match = key.match(period === 'daily' ?
|
||||
/usage:.+:model:daily:(.+):\d{4}-\d{2}-\d{2}$/ :
|
||||
/usage:.+:model:monthly:(.+):\d{4}-\d{2}$/
|
||||
);
|
||||
|
||||
if (!match) continue;
|
||||
|
||||
const model = match[1];
|
||||
const data = await client.hgetall(key);
|
||||
|
||||
const match = key.match(
|
||||
period === 'daily'
|
||||
? /usage:.+:model:daily:(.+):\d{4}-\d{2}-\d{2}$/
|
||||
: /usage:.+:model:monthly:(.+):\d{4}-\d{2}$/
|
||||
)
|
||||
|
||||
if (!match) {
|
||||
continue
|
||||
}
|
||||
|
||||
const model = match[1]
|
||||
const data = await client.hgetall(key)
|
||||
|
||||
if (data && Object.keys(data).length > 0) {
|
||||
const usage = {
|
||||
input_tokens: parseInt(data.inputTokens) || 0,
|
||||
output_tokens: parseInt(data.outputTokens) || 0,
|
||||
cache_creation_input_tokens: parseInt(data.cacheCreateTokens) || 0,
|
||||
cache_read_input_tokens: parseInt(data.cacheReadTokens) || 0
|
||||
};
|
||||
|
||||
const costData = CostCalculator.calculateCost(usage, model);
|
||||
|
||||
}
|
||||
|
||||
const costData = CostCalculator.calculateCost(usage, model)
|
||||
|
||||
modelStats.push({
|
||||
model,
|
||||
requests: parseInt(data.requests) || 0,
|
||||
@@ -475,32 +498,31 @@ router.post('/api/user-model-stats', async (req, res) => {
|
||||
costs: costData.costs,
|
||||
formatted: costData.formatted,
|
||||
pricing: costData.pricing
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有详细的模型数据,不显示历史数据以避免混淆
|
||||
// 只有在查询特定时间段时返回空数组,表示该时间段确实没有数据
|
||||
if (modelStats.length === 0) {
|
||||
logger.info(`📊 No model stats found for key ${keyId} in period ${period}`);
|
||||
logger.info(`📊 No model stats found for key ${keyId} in period ${period}`)
|
||||
}
|
||||
|
||||
// 按总token数降序排列
|
||||
modelStats.sort((a, b) => b.allTokens - a.allTokens);
|
||||
modelStats.sort((a, b) => b.allTokens - a.allTokens)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
data: modelStats,
|
||||
period: period
|
||||
});
|
||||
|
||||
period
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to process user model stats query:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Failed to process user model stats query:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: 'Failed to retrieve model statistics'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const logger = require('../utils/logger');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const geminiAccountService = require('../services/geminiAccountService');
|
||||
const { sendGeminiRequest, getAvailableModels } = require('../services/geminiRelayService');
|
||||
const crypto = require('crypto');
|
||||
const sessionHelper = require('../utils/sessionHelper');
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const express = require('express')
|
||||
const router = express.Router()
|
||||
const logger = require('../utils/logger')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const geminiAccountService = require('../services/geminiAccountService')
|
||||
const { sendGeminiRequest, getAvailableModels } = require('../services/geminiRelayService')
|
||||
const crypto = require('crypto')
|
||||
const sessionHelper = require('../utils/sessionHelper')
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
// const { OAuth2Client } = require('google-auth-library'); // OAuth2Client is not used in this file
|
||||
|
||||
// 生成会话哈希
|
||||
@@ -16,24 +16,26 @@ function generateSessionHash(req) {
|
||||
req.headers['user-agent'],
|
||||
req.ip,
|
||||
req.headers['x-api-key']?.substring(0, 10)
|
||||
].filter(Boolean).join(':');
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(':')
|
||||
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex');
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex')
|
||||
}
|
||||
|
||||
// 检查 API Key 权限
|
||||
function checkPermissions(apiKeyData, requiredPermission = 'gemini') {
|
||||
const permissions = apiKeyData.permissions || 'all';
|
||||
return permissions === 'all' || permissions === requiredPermission;
|
||||
const permissions = apiKeyData.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// Gemini 消息处理端点
|
||||
router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
const startTime = Date.now();
|
||||
let abortController = null;
|
||||
const startTime = Date.now()
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
@@ -42,7 +44,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
message: 'This API key does not have permission to access Gemini',
|
||||
type: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 提取请求参数
|
||||
@@ -52,7 +54,7 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
temperature = 0.7,
|
||||
max_tokens = 4096,
|
||||
stream = false
|
||||
} = req.body;
|
||||
} = req.body
|
||||
|
||||
// 验证必需参数
|
||||
if (!messages || !Array.isArray(messages) || messages.length === 0) {
|
||||
@@ -61,57 +63,58 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
message: 'Messages array is required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 生成会话哈希用于粘性会话
|
||||
const sessionHash = generateSessionHash(req);
|
||||
const sessionHash = generateSessionHash(req)
|
||||
|
||||
// 使用统一调度选择可用的 Gemini 账户(传递请求的模型)
|
||||
let accountId;
|
||||
let accountId
|
||||
try {
|
||||
const schedulerResult = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
sessionHash,
|
||||
model // 传递请求的模型进行过滤
|
||||
);
|
||||
accountId = schedulerResult.accountId;
|
||||
model // 传递请求的模型进行过滤
|
||||
)
|
||||
const { accountId: selectedAccountId } = schedulerResult
|
||||
accountId = selectedAccountId
|
||||
} catch (error) {
|
||||
logger.error('Failed to select Gemini account:', error);
|
||||
logger.error('Failed to select Gemini account:', error)
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
message: error.message || 'No available Gemini accounts',
|
||||
type: 'service_unavailable'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取账户详情
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
if (!account) {
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
message: 'Selected account not found',
|
||||
type: 'service_unavailable'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`);
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`)
|
||||
|
||||
// 标记账户被使用
|
||||
await geminiAccountService.markAccountUsed(account.id);
|
||||
await geminiAccountService.markAccountUsed(account.id)
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开连接
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('Client disconnected, aborting Gemini request');
|
||||
abortController.abort();
|
||||
logger.info('Client disconnected, aborting Gemini request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 发送请求到 Gemini
|
||||
const geminiResponse = await sendGeminiRequest({
|
||||
@@ -126,64 +129,64 @@ router.post('/messages', authenticateApiKey, async (req, res) => {
|
||||
signal: abortController.signal,
|
||||
projectId: account.projectId,
|
||||
accountId: account.id
|
||||
});
|
||||
})
|
||||
|
||||
if (stream) {
|
||||
// 设置流式响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 流式传输响应
|
||||
for await (const chunk of geminiResponse) {
|
||||
if (abortController.signal.aborted) {
|
||||
break;
|
||||
break
|
||||
}
|
||||
res.write(chunk);
|
||||
res.write(chunk)
|
||||
}
|
||||
|
||||
res.end();
|
||||
res.end()
|
||||
} else {
|
||||
// 非流式响应
|
||||
res.json(geminiResponse);
|
||||
res.json(geminiResponse)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.info(`Gemini request completed in ${duration}ms`);
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`Gemini request completed in ${duration}ms`)
|
||||
} catch (error) {
|
||||
logger.error('Gemini request error:', error);
|
||||
logger.error('Gemini request error:', error)
|
||||
|
||||
// 处理速率限制
|
||||
if (error.status === 429) {
|
||||
if (req.apiKey && req.account) {
|
||||
await geminiAccountService.setAccountRateLimited(req.account.id, true);
|
||||
await geminiAccountService.setAccountRateLimited(req.account.id, true)
|
||||
}
|
||||
}
|
||||
|
||||
// 返回错误响应
|
||||
const status = error.status || 500;
|
||||
const status = error.status || 500
|
||||
const errorResponse = {
|
||||
error: error.error || {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
res.status(status).json(errorResponse);
|
||||
res.status(status).json(errorResponse)
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 获取可用模型列表
|
||||
router.get('/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
@@ -192,16 +195,20 @@ router.get('/models', authenticateApiKey, async (req, res) => {
|
||||
message: 'This API key does not have permission to access Gemini',
|
||||
type: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 选择账户获取模型列表
|
||||
let account = null;
|
||||
let account = null
|
||||
try {
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(apiKeyData, null, null);
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId);
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
null,
|
||||
null
|
||||
)
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error);
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error)
|
||||
}
|
||||
|
||||
if (!account) {
|
||||
@@ -216,32 +223,32 @@ router.get('/models', authenticateApiKey, async (req, res) => {
|
||||
owned_by: 'google'
|
||||
}
|
||||
]
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取模型列表
|
||||
const models = await getAvailableModels(account.accessToken, account.proxy);
|
||||
const models = await getAvailableModels(account.accessToken, account.proxy)
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get Gemini models:', error);
|
||||
logger.error('Failed to get Gemini models:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 使用情况统计(与 Claude 共用)
|
||||
router.get('/usage', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const usage = req.apiKey.usage;
|
||||
const { usage } = req.apiKey
|
||||
|
||||
res.json({
|
||||
object: 'usage',
|
||||
@@ -251,22 +258,22 @@ router.get('/usage', authenticateApiKey, async (req, res) => {
|
||||
daily_requests: usage.daily.requests,
|
||||
monthly_tokens: usage.monthly.tokens,
|
||||
monthly_requests: usage.monthly.requests
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get usage stats:', error);
|
||||
logger.error('Failed to get usage stats:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve usage statistics',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// API Key 信息(与 Claude 共用)
|
||||
router.get('/key-info', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const keyData = req.apiKey;
|
||||
const keyData = req.apiKey
|
||||
|
||||
res.json({
|
||||
id: keyData.id,
|
||||
@@ -274,9 +281,10 @@ router.get('/key-info', authenticateApiKey, async (req, res) => {
|
||||
permissions: keyData.permissions || 'all',
|
||||
token_limit: keyData.tokenLimit,
|
||||
tokens_used: keyData.usage.total.tokens,
|
||||
tokens_remaining: keyData.tokenLimit > 0
|
||||
? Math.max(0, keyData.tokenLimit - keyData.usage.total.tokens)
|
||||
: null,
|
||||
tokens_remaining:
|
||||
keyData.tokenLimit > 0
|
||||
? Math.max(0, keyData.tokenLimit - keyData.usage.total.tokens)
|
||||
: null,
|
||||
rate_limit: {
|
||||
window: keyData.rateLimitWindow,
|
||||
requests: keyData.rateLimitRequests
|
||||
@@ -286,88 +294,105 @@ router.get('/key-info', authenticateApiKey, async (req, res) => {
|
||||
enabled: keyData.enableModelRestriction,
|
||||
models: keyData.restrictedModels
|
||||
}
|
||||
});
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get key info:', error);
|
||||
logger.error('Failed to get key info:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve API key information',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 共用的 loadCodeAssist 处理函数
|
||||
async function handleLoadCodeAssist(req, res) {
|
||||
try {
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);
|
||||
logger.info(`accessToken: ${accessToken}`);
|
||||
const requestedModel = req.body.model
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId)
|
||||
logger.info(`accessToken: ${accessToken}`)
|
||||
|
||||
const { metadata, cloudaicompanionProject } = req.body;
|
||||
const { metadata, cloudaicompanionProject } = req.body
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`LoadCodeAssist request (${version})`, {
|
||||
metadata: metadata || {},
|
||||
cloudaicompanionProject: cloudaicompanionProject || null,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const response = await geminiAccountService.loadCodeAssist(client, cloudaicompanionProject);
|
||||
res.json(response);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const response = await geminiAccountService.loadCodeAssist(client, cloudaicompanionProject)
|
||||
res.json(response)
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in loadCodeAssist endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in loadCodeAssist endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 共用的 onboardUser 处理函数
|
||||
async function handleOnboardUser(req, res) {
|
||||
try {
|
||||
const { tierId, cloudaicompanionProject, metadata } = req.body;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const { tierId, cloudaicompanionProject, metadata } = req.body
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 使用统一调度选择账号(传递请求的模型)
|
||||
const requestedModel = req.body.model;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, requestedModel);
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);
|
||||
const requestedModel = req.body.model
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
requestedModel
|
||||
)
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId)
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`OnboardUser request (${version})`, {
|
||||
tierId: tierId || 'not provided',
|
||||
cloudaicompanionProject: cloudaicompanionProject || null,
|
||||
metadata: metadata || {},
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
|
||||
// 如果提供了完整参数,直接调用onboardUser
|
||||
if (tierId && metadata) {
|
||||
const response = await geminiAccountService.onboardUser(client, tierId, cloudaicompanionProject, metadata);
|
||||
res.json(response);
|
||||
const response = await geminiAccountService.onboardUser(
|
||||
client,
|
||||
tierId,
|
||||
cloudaicompanionProject,
|
||||
metadata
|
||||
)
|
||||
res.json(response)
|
||||
} else {
|
||||
// 否则执行完整的setupUser流程
|
||||
const response = await geminiAccountService.setupUser(client, cloudaicompanionProject, metadata);
|
||||
res.json(response);
|
||||
const response = await geminiAccountService.setupUser(
|
||||
client,
|
||||
cloudaicompanionProject,
|
||||
metadata
|
||||
)
|
||||
res.json(response)
|
||||
}
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in onboardUser endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in onboardUser endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: 'Internal server error',
|
||||
message: error.message
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,9 +400,9 @@ async function handleOnboardUser(req, res) {
|
||||
async function handleCountTokens(req, res) {
|
||||
try {
|
||||
// 处理请求体结构,支持直接 contents 或 request.contents
|
||||
const requestData = req.body.request || req.body;
|
||||
const { contents, model = 'gemini-2.0-flash-exp' } = requestData;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const requestData = req.body.request || req.body
|
||||
const { contents, model = 'gemini-2.0-flash-exp' } = requestData
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 验证必需参数
|
||||
if (!contents || !Array.isArray(contents)) {
|
||||
@@ -386,49 +411,54 @@ async function handleCountTokens(req, res) {
|
||||
message: 'Contents array is required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 使用统一调度选择账号
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model);
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId);
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
const { accessToken, refreshToken } = await geminiAccountService.getAccount(accountId)
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`CountTokens request (${version})`, {
|
||||
model: model,
|
||||
model,
|
||||
contentsLength: contents.length,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const response = await geminiAccountService.countTokens(client, contents, model);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const response = await geminiAccountService.countTokens(client, contents, model)
|
||||
|
||||
res.json(response);
|
||||
res.json(response)
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in countTokens endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in countTokens endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 共用的 generateContent 处理函数
|
||||
async function handleGenerateContent(req, res) {
|
||||
try {
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 处理不同格式的请求
|
||||
let actualRequestData = requestData;
|
||||
let actualRequestData = requestData
|
||||
if (!requestData) {
|
||||
if (req.body.messages) {
|
||||
// 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象
|
||||
actualRequestData = {
|
||||
contents: req.body.messages.map(msg => ({
|
||||
contents: req.body.messages.map((msg) => ({
|
||||
role: msg.role === 'assistant' ? 'model' : msg.role,
|
||||
parts: [{ text: msg.content }]
|
||||
})),
|
||||
@@ -438,10 +468,10 @@ async function handleGenerateContent(req, res) {
|
||||
topP: req.body.top_p !== undefined ? req.body.top_p : 0.95,
|
||||
topK: req.body.top_k !== undefined ? req.body.top_k : 40
|
||||
}
|
||||
};
|
||||
}
|
||||
} else if (req.body.contents) {
|
||||
// 直接的 Gemini 格式请求(没有 request 包装)
|
||||
actualRequestData = req.body;
|
||||
actualRequestData = req.body
|
||||
}
|
||||
}
|
||||
|
||||
@@ -452,35 +482,39 @@ async function handleGenerateContent(req, res) {
|
||||
message: 'Request contents are required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 使用统一调度选择账号
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model);
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
const { accessToken, refreshToken } = account;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
const { accessToken, refreshToken } = account
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`GenerateContent request (${version})`, {
|
||||
model: model,
|
||||
model,
|
||||
userPromptId: user_prompt_id,
|
||||
projectId: project || account.projectId,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const response = await geminiAccountService.generateContent(
|
||||
client,
|
||||
{ model, request: actualRequestData },
|
||||
user_prompt_id,
|
||||
project || account.projectId,
|
||||
req.apiKey?.id // 使用 API Key ID 作为 session ID
|
||||
);
|
||||
)
|
||||
|
||||
// 记录使用统计
|
||||
if (response?.response?.usageMetadata) {
|
||||
try {
|
||||
const usage = response.response.usageMetadata;
|
||||
const usage = response.response.usageMetadata
|
||||
await apiKeyService.recordUsage(
|
||||
req.apiKey.id,
|
||||
usage.promptTokenCount || 0,
|
||||
@@ -489,42 +523,45 @@ async function handleGenerateContent(req, res) {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini usage - Input: ${usage.promptTokenCount}, Output: ${usage.candidatesTokenCount}, Total: ${usage.totalTokenCount}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini usage - Input: ${usage.promptTokenCount}, Output: ${usage.candidatesTokenCount}, Total: ${usage.totalTokenCount}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
res.json(response);
|
||||
res.json(response)
|
||||
} catch (error) {
|
||||
console.log(321, error.response);
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in generateContent endpoint (${version})`, { error: error.message });
|
||||
console.log(321, error.response)
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in generateContent endpoint (${version})`, { error: error.message })
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 共用的 streamGenerateContent 处理函数
|
||||
async function handleStreamGenerateContent(req, res) {
|
||||
let abortController = null;
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body;
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body);
|
||||
const { model, project, user_prompt_id, request: requestData } = req.body
|
||||
const sessionHash = sessionHelper.generateSessionHash(req.body)
|
||||
|
||||
// 处理不同格式的请求
|
||||
let actualRequestData = requestData;
|
||||
let actualRequestData = requestData
|
||||
if (!requestData) {
|
||||
if (req.body.messages) {
|
||||
// 这是 OpenAI 格式的请求,构建 Gemini 格式的 request 对象
|
||||
actualRequestData = {
|
||||
contents: req.body.messages.map(msg => ({
|
||||
contents: req.body.messages.map((msg) => ({
|
||||
role: msg.role === 'assistant' ? 'model' : msg.role,
|
||||
parts: [{ text: msg.content }]
|
||||
})),
|
||||
@@ -534,10 +571,10 @@ async function handleStreamGenerateContent(req, res) {
|
||||
topP: req.body.top_p !== undefined ? req.body.top_p : 0.95,
|
||||
topK: req.body.top_k !== undefined ? req.body.top_k : 40
|
||||
}
|
||||
};
|
||||
}
|
||||
} else if (req.body.contents) {
|
||||
// 直接的 Gemini 格式请求(没有 request 包装)
|
||||
actualRequestData = req.body;
|
||||
actualRequestData = req.body
|
||||
}
|
||||
}
|
||||
|
||||
@@ -548,34 +585,38 @@ async function handleStreamGenerateContent(req, res) {
|
||||
message: 'Request contents are required',
|
||||
type: 'invalid_request_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 使用统一调度选择账号
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(req.apiKey, sessionHash, model);
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
const { accessToken, refreshToken } = account;
|
||||
const { accountId } = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
req.apiKey,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
const { accessToken, refreshToken } = account
|
||||
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.info(`StreamGenerateContent request (${version})`, {
|
||||
model: model,
|
||||
model,
|
||||
userPromptId: user_prompt_id,
|
||||
projectId: project || account.projectId,
|
||||
apiKeyId: req.apiKey?.id || 'unknown'
|
||||
});
|
||||
})
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开连接
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('Client disconnected, aborting stream request');
|
||||
abortController.abort();
|
||||
logger.info('Client disconnected, aborting stream request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(accessToken, refreshToken)
|
||||
const streamResponse = await geminiAccountService.generateContentStream(
|
||||
client,
|
||||
{ model, request: actualRequestData },
|
||||
@@ -583,48 +624,48 @@ async function handleStreamGenerateContent(req, res) {
|
||||
project || account.projectId,
|
||||
req.apiKey?.id, // 使用 API Key ID 作为 session ID
|
||||
abortController.signal // 传递中止信号
|
||||
);
|
||||
)
|
||||
|
||||
// 设置 SSE 响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 处理流式响应并捕获usage数据
|
||||
let buffer = '';
|
||||
let buffer = ''
|
||||
let totalUsage = {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
let usageReported = false;
|
||||
}
|
||||
const usageReported = false
|
||||
|
||||
streamResponse.on('data', (chunk) => {
|
||||
try {
|
||||
const chunkStr = chunk.toString();
|
||||
|
||||
const chunkStr = chunk.toString()
|
||||
|
||||
// 直接转发数据到客户端
|
||||
if (!res.destroyed) {
|
||||
res.write(chunkStr);
|
||||
res.write(chunkStr)
|
||||
}
|
||||
|
||||
// 同时解析数据以捕获usage信息
|
||||
buffer += chunkStr;
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || '';
|
||||
buffer += chunkStr
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ') && line.length > 6) {
|
||||
try {
|
||||
const jsonStr = line.slice(6);
|
||||
const jsonStr = line.slice(6)
|
||||
if (jsonStr && jsonStr !== '[DONE]') {
|
||||
const data = JSON.parse(jsonStr);
|
||||
|
||||
const data = JSON.parse(jsonStr)
|
||||
|
||||
// 从响应中提取usage数据
|
||||
if (data.response?.usageMetadata) {
|
||||
totalUsage = data.response.usageMetadata;
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage);
|
||||
totalUsage = data.response.usageMetadata
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
@@ -633,13 +674,13 @@ async function handleStreamGenerateContent(req, res) {
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error processing stream chunk:', error);
|
||||
logger.error('Error processing stream chunk:', error)
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
streamResponse.on('end', async () => {
|
||||
logger.info('Stream completed successfully');
|
||||
|
||||
logger.info('Stream completed successfully')
|
||||
|
||||
// 记录使用统计
|
||||
if (!usageReported && totalUsage.totalTokenCount > 0) {
|
||||
try {
|
||||
@@ -651,33 +692,34 @@ async function handleStreamGenerateContent(req, res) {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
res.end();
|
||||
});
|
||||
|
||||
res.end()
|
||||
})
|
||||
|
||||
streamResponse.on('error', (error) => {
|
||||
logger.error('Stream error:', error);
|
||||
logger.error('Stream error:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Stream error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
} else {
|
||||
res.end();
|
||||
res.end()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal';
|
||||
logger.error(`Error in streamGenerateContent endpoint (${version})`, { error: error.message });
|
||||
const version = req.path.includes('v1beta') ? 'v1beta' : 'v1internal'
|
||||
logger.error(`Error in streamGenerateContent endpoint (${version})`, { error: error.message })
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
@@ -685,29 +727,38 @@ async function handleStreamGenerateContent(req, res) {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 注册所有路由端点
|
||||
// v1internal 版本的端点
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist);
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, handleOnboardUser);
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, handleCountTokens);
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, handleGenerateContent);
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent);
|
||||
router.post('/v1internal\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist)
|
||||
router.post('/v1internal\\:onboardUser', authenticateApiKey, handleOnboardUser)
|
||||
router.post('/v1internal\\:countTokens', authenticateApiKey, handleCountTokens)
|
||||
router.post('/v1internal\\:generateContent', authenticateApiKey, handleGenerateContent)
|
||||
router.post('/v1internal\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent)
|
||||
|
||||
// v1beta 版本的端点 - 支持动态模型名称
|
||||
router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist);
|
||||
router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, handleOnboardUser);
|
||||
router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, handleCountTokens);
|
||||
router.post('/v1beta/models/:modelName\\:generateContent', authenticateApiKey, handleGenerateContent);
|
||||
router.post('/v1beta/models/:modelName\\:streamGenerateContent', authenticateApiKey, handleStreamGenerateContent);
|
||||
router.post('/v1beta/models/:modelName\\:loadCodeAssist', authenticateApiKey, handleLoadCodeAssist)
|
||||
router.post('/v1beta/models/:modelName\\:onboardUser', authenticateApiKey, handleOnboardUser)
|
||||
router.post('/v1beta/models/:modelName\\:countTokens', authenticateApiKey, handleCountTokens)
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:generateContent',
|
||||
authenticateApiKey,
|
||||
handleGenerateContent
|
||||
)
|
||||
router.post(
|
||||
'/v1beta/models/:modelName\\:streamGenerateContent',
|
||||
authenticateApiKey,
|
||||
handleStreamGenerateContent
|
||||
)
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -3,41 +3,41 @@
|
||||
* 提供 OpenAI 格式的 API 接口,内部转发到 Claude
|
||||
*/
|
||||
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const logger = require('../utils/logger');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const claudeRelayService = require('../services/claudeRelayService');
|
||||
const openaiToClaude = require('../services/openaiToClaude');
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler');
|
||||
const claudeCodeHeadersService = require('../services/claudeCodeHeadersService');
|
||||
const sessionHelper = require('../utils/sessionHelper');
|
||||
const express = require('express')
|
||||
const router = express.Router()
|
||||
const fs = require('fs')
|
||||
const path = require('path')
|
||||
const logger = require('../utils/logger')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const claudeRelayService = require('../services/claudeRelayService')
|
||||
const openaiToClaude = require('../services/openaiToClaude')
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
const unifiedClaudeScheduler = require('../services/unifiedClaudeScheduler')
|
||||
const claudeCodeHeadersService = require('../services/claudeCodeHeadersService')
|
||||
const sessionHelper = require('../utils/sessionHelper')
|
||||
|
||||
// 加载模型定价数据
|
||||
let modelPricingData = {};
|
||||
let modelPricingData = {}
|
||||
try {
|
||||
const pricingPath = path.join(__dirname, '../../data/model_pricing.json');
|
||||
const pricingContent = fs.readFileSync(pricingPath, 'utf8');
|
||||
modelPricingData = JSON.parse(pricingContent);
|
||||
logger.info('✅ Model pricing data loaded successfully');
|
||||
const pricingPath = path.join(__dirname, '../../data/model_pricing.json')
|
||||
const pricingContent = fs.readFileSync(pricingPath, 'utf8')
|
||||
modelPricingData = JSON.parse(pricingContent)
|
||||
logger.info('✅ Model pricing data loaded successfully')
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to load model pricing data:', error);
|
||||
logger.error('❌ Failed to load model pricing data:', error)
|
||||
}
|
||||
|
||||
// 🔧 辅助函数:检查 API Key 权限
|
||||
function checkPermissions(apiKeyData, requiredPermission = 'claude') {
|
||||
const permissions = apiKeyData.permissions || 'all';
|
||||
return permissions === 'all' || permissions === requiredPermission;
|
||||
const permissions = apiKeyData.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// 📋 OpenAI 兼容的模型列表端点
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'claude')) {
|
||||
return res.status(403).json({
|
||||
@@ -46,9 +46,9 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// Claude 模型列表 - 只返回 opus-4 和 sonnet-4
|
||||
let models = [
|
||||
{
|
||||
@@ -63,36 +63,36 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
created: 1736726400, // 2025-01-13
|
||||
owned_by: 'anthropic'
|
||||
}
|
||||
];
|
||||
|
||||
]
|
||||
|
||||
// 如果启用了模型限制,过滤模型列表
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels?.length > 0) {
|
||||
models = models.filter(model => apiKeyData.restrictedModels.includes(model.id));
|
||||
models = models.filter((model) => apiKeyData.restrictedModels.includes(model.id))
|
||||
}
|
||||
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get OpenAI-Claude models:', error);
|
||||
logger.error('❌ Failed to get OpenAI-Claude models:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 📄 OpenAI 兼容的模型详情端点
|
||||
router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const modelId = req.params.model;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
const modelId = req.params.model
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'claude')) {
|
||||
return res.status(403).json({
|
||||
@@ -101,9 +101,9 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels?.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(modelId)) {
|
||||
@@ -113,16 +113,16 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_found'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 从 model_pricing.json 获取模型信息
|
||||
const modelData = modelPricingData[modelId];
|
||||
|
||||
const modelData = modelPricingData[modelId]
|
||||
|
||||
// 构建标准 OpenAI 格式的模型响应
|
||||
let modelInfo;
|
||||
|
||||
let modelInfo
|
||||
|
||||
if (modelData) {
|
||||
// 如果在 pricing 文件中找到了模型
|
||||
modelInfo = {
|
||||
@@ -133,7 +133,7 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
permission: [],
|
||||
root: modelId,
|
||||
parent: null
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// 如果没找到,返回默认信息(但仍保持正确格式)
|
||||
modelInfo = {
|
||||
@@ -144,28 +144,28 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
permission: [],
|
||||
root: modelId,
|
||||
parent: null
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
res.json(modelInfo);
|
||||
|
||||
|
||||
res.json(modelInfo)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get model details:', error);
|
||||
logger.error('❌ Failed to get model details:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve model details',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// 🔧 处理聊天完成请求的核心函数
|
||||
async function handleChatCompletion(req, res, apiKeyData) {
|
||||
const startTime = Date.now();
|
||||
let abortController = null;
|
||||
|
||||
const startTime = Date.now()
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'claude')) {
|
||||
@@ -175,20 +175,20 @@ async function handleChatCompletion(req, res, apiKeyData) {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 记录原始请求
|
||||
logger.debug('📥 Received OpenAI format request:', {
|
||||
model: req.body.model,
|
||||
messageCount: req.body.messages?.length,
|
||||
stream: req.body.stream,
|
||||
maxTokens: req.body.max_tokens
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 转换 OpenAI 请求为 Claude 格式
|
||||
const claudeRequest = openaiToClaude.convertRequest(req.body);
|
||||
|
||||
const claudeRequest = openaiToClaude.convertRequest(req.body)
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels?.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(claudeRequest.model)) {
|
||||
@@ -198,114 +198,119 @@ async function handleChatCompletion(req, res, apiKeyData) {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_allowed'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(claudeRequest);
|
||||
|
||||
const sessionHash = sessionHelper.generateSessionHash(claudeRequest)
|
||||
|
||||
// 选择可用的Claude账户
|
||||
const accountSelection = await unifiedClaudeScheduler.selectAccountForApiKey(apiKeyData, sessionHash, claudeRequest.model);
|
||||
const accountId = accountSelection.accountId;
|
||||
|
||||
const accountSelection = await unifiedClaudeScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
sessionHash,
|
||||
claudeRequest.model
|
||||
)
|
||||
const { accountId } = accountSelection
|
||||
|
||||
// 获取该账号存储的 Claude Code headers
|
||||
const claudeCodeHeaders = await claudeCodeHeadersService.getAccountHeaders(accountId);
|
||||
|
||||
const claudeCodeHeaders = await claudeCodeHeadersService.getAccountHeaders(accountId)
|
||||
|
||||
logger.debug(`📋 Using Claude Code headers for account ${accountId}:`, {
|
||||
userAgent: claudeCodeHeaders['user-agent']
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 处理流式请求
|
||||
if (claudeRequest.stream) {
|
||||
logger.info(`🌊 Processing OpenAI stream request for model: ${req.body.model}`);
|
||||
|
||||
logger.info(`🌊 Processing OpenAI stream request for model: ${req.body.model}`)
|
||||
|
||||
// 设置 SSE 响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
|
||||
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('🔌 Client disconnected, aborting Claude request');
|
||||
abortController.abort();
|
||||
logger.info('🔌 Client disconnected, aborting Claude request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 使用转换后的响应流 (使用 OAuth-only beta header,添加 Claude Code 必需的 headers)
|
||||
await claudeRelayService.relayStreamRequestWithUsageCapture(
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
res,
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
res,
|
||||
claudeCodeHeaders,
|
||||
(usage) => {
|
||||
// 记录使用统计
|
||||
if (usage && usage.input_tokens !== undefined && usage.output_tokens !== undefined) {
|
||||
const inputTokens = usage.input_tokens || 0;
|
||||
const outputTokens = usage.output_tokens || 0;
|
||||
const cacheCreateTokens = usage.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0;
|
||||
const model = usage.model || claudeRequest.model;
|
||||
|
||||
apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
accountId
|
||||
).catch(error => {
|
||||
logger.error('❌ Failed to record usage:', error);
|
||||
});
|
||||
const inputTokens = usage.input_tokens || 0
|
||||
const outputTokens = usage.output_tokens || 0
|
||||
const cacheCreateTokens = usage.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0
|
||||
const model = usage.model || claudeRequest.model
|
||||
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
apiKeyData.id,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model,
|
||||
accountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record usage:', error)
|
||||
})
|
||||
}
|
||||
},
|
||||
// 流转换器
|
||||
(() => {
|
||||
// 为每个请求创建独立的会话ID
|
||||
const sessionId = `chatcmpl-${Math.random().toString(36).substring(2, 15)}${Math.random().toString(36).substring(2, 15)}`;
|
||||
return (chunk) => {
|
||||
return openaiToClaude.convertStreamChunk(chunk, req.body.model, sessionId);
|
||||
};
|
||||
const sessionId = `chatcmpl-${Math.random().toString(36).substring(2, 15)}${Math.random().toString(36).substring(2, 15)}`
|
||||
return (chunk) => openaiToClaude.convertStreamChunk(chunk, req.body.model, sessionId)
|
||||
})(),
|
||||
{ betaHeader: 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14' }
|
||||
);
|
||||
|
||||
{
|
||||
betaHeader:
|
||||
'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14'
|
||||
}
|
||||
)
|
||||
} else {
|
||||
// 非流式请求
|
||||
logger.info(`📄 Processing OpenAI non-stream request for model: ${req.body.model}`);
|
||||
|
||||
logger.info(`📄 Processing OpenAI non-stream request for model: ${req.body.model}`)
|
||||
|
||||
// 发送请求到 Claude (使用 OAuth-only beta header,添加 Claude Code 必需的 headers)
|
||||
const claudeResponse = await claudeRelayService.relayRequest(
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
req,
|
||||
res,
|
||||
claudeRequest,
|
||||
apiKeyData,
|
||||
req,
|
||||
res,
|
||||
claudeCodeHeaders,
|
||||
{ betaHeader: 'oauth-2025-04-20' }
|
||||
);
|
||||
|
||||
)
|
||||
|
||||
// 解析 Claude 响应
|
||||
let claudeData;
|
||||
let claudeData
|
||||
try {
|
||||
claudeData = JSON.parse(claudeResponse.body);
|
||||
claudeData = JSON.parse(claudeResponse.body)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to parse Claude response:', error);
|
||||
logger.error('❌ Failed to parse Claude response:', error)
|
||||
return res.status(502).json({
|
||||
error: {
|
||||
message: 'Invalid response from Claude API',
|
||||
type: 'api_error',
|
||||
code: 'invalid_response'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 处理错误响应
|
||||
if (claudeResponse.statusCode >= 400) {
|
||||
return res.status(claudeResponse.statusCode).json({
|
||||
@@ -314,64 +319,66 @@ async function handleChatCompletion(req, res, apiKeyData) {
|
||||
type: claudeData.error?.type || 'api_error',
|
||||
code: claudeData.error?.code || 'unknown_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 转换为 OpenAI 格式
|
||||
const openaiResponse = openaiToClaude.convertResponse(claudeData, req.body.model);
|
||||
|
||||
const openaiResponse = openaiToClaude.convertResponse(claudeData, req.body.model)
|
||||
|
||||
// 记录使用统计
|
||||
if (claudeData.usage) {
|
||||
const usage = claudeData.usage;
|
||||
apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
usage.input_tokens || 0,
|
||||
usage.output_tokens || 0,
|
||||
usage.cache_creation_input_tokens || 0,
|
||||
usage.cache_read_input_tokens || 0,
|
||||
claudeRequest.model,
|
||||
accountId
|
||||
).catch(error => {
|
||||
logger.error('❌ Failed to record usage:', error);
|
||||
});
|
||||
const { usage } = claudeData
|
||||
apiKeyService
|
||||
.recordUsage(
|
||||
apiKeyData.id,
|
||||
usage.input_tokens || 0,
|
||||
usage.output_tokens || 0,
|
||||
usage.cache_creation_input_tokens || 0,
|
||||
usage.cache_read_input_tokens || 0,
|
||||
claudeRequest.model,
|
||||
accountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record usage:', error)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 返回 OpenAI 格式响应
|
||||
res.json(openaiResponse);
|
||||
res.json(openaiResponse)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.info(`✅ OpenAI-Claude request completed in ${duration}ms`);
|
||||
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`✅ OpenAI-Claude request completed in ${duration}ms`)
|
||||
} catch (error) {
|
||||
logger.error('❌ OpenAI-Claude request error:', error);
|
||||
|
||||
const status = error.status || 500;
|
||||
logger.error('❌ OpenAI-Claude request error:', error)
|
||||
|
||||
const status = error.status || 500
|
||||
res.status(status).json({
|
||||
error: {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 🚀 OpenAI 兼容的聊天完成端点
|
||||
router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
await handleChatCompletion(req, res, req.apiKey);
|
||||
});
|
||||
await handleChatCompletion(req, res, req.apiKey)
|
||||
})
|
||||
|
||||
// 🔧 OpenAI 兼容的 completions 端点(传统格式,转换为 chat 格式)
|
||||
router.post('/v1/completions', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 验证必需参数
|
||||
if (!req.body.prompt) {
|
||||
return res.status(400).json({
|
||||
@@ -380,11 +387,11 @@ router.post('/v1/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'invalid_request'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 将传统 completions 格式转换为 chat 格式
|
||||
const originalBody = req.body;
|
||||
const originalBody = req.body
|
||||
req.body = {
|
||||
model: originalBody.model,
|
||||
messages: [
|
||||
@@ -403,21 +410,21 @@ router.post('/v1/completions', authenticateApiKey, async (req, res) => {
|
||||
frequency_penalty: originalBody.frequency_penalty,
|
||||
logit_bias: originalBody.logit_bias,
|
||||
user: originalBody.user
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 使用共享的处理函数
|
||||
await handleChatCompletion(req, res, apiKeyData);
|
||||
|
||||
await handleChatCompletion(req, res, apiKeyData)
|
||||
} catch (error) {
|
||||
logger.error('❌ OpenAI completions error:', error);
|
||||
logger.error('❌ OpenAI completions error:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to process completion request',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
const express = require('express');
|
||||
const router = express.Router();
|
||||
const logger = require('../utils/logger');
|
||||
const { authenticateApiKey } = require('../middleware/auth');
|
||||
const geminiAccountService = require('../services/geminiAccountService');
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler');
|
||||
const { getAvailableModels } = require('../services/geminiRelayService');
|
||||
const crypto = require('crypto');
|
||||
const express = require('express')
|
||||
const router = express.Router()
|
||||
const logger = require('../utils/logger')
|
||||
const { authenticateApiKey } = require('../middleware/auth')
|
||||
const geminiAccountService = require('../services/geminiAccountService')
|
||||
const unifiedGeminiScheduler = require('../services/unifiedGeminiScheduler')
|
||||
const { getAvailableModels } = require('../services/geminiRelayService')
|
||||
const crypto = require('crypto')
|
||||
|
||||
// 生成会话哈希
|
||||
function generateSessionHash(req) {
|
||||
@@ -13,167 +13,182 @@ function generateSessionHash(req) {
|
||||
req.headers['user-agent'],
|
||||
req.ip,
|
||||
req.headers['authorization']?.substring(0, 20)
|
||||
].filter(Boolean).join(':');
|
||||
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex');
|
||||
]
|
||||
.filter(Boolean)
|
||||
.join(':')
|
||||
|
||||
return crypto.createHash('sha256').update(sessionData).digest('hex')
|
||||
}
|
||||
|
||||
// 检查 API Key 权限
|
||||
function checkPermissions(apiKeyData, requiredPermission = 'gemini') {
|
||||
const permissions = apiKeyData.permissions || 'all';
|
||||
return permissions === 'all' || permissions === requiredPermission;
|
||||
const permissions = apiKeyData.permissions || 'all'
|
||||
return permissions === 'all' || permissions === requiredPermission
|
||||
}
|
||||
|
||||
// 转换 OpenAI 消息格式到 Gemini 格式
|
||||
function convertMessagesToGemini(messages) {
|
||||
const contents = [];
|
||||
let systemInstruction = '';
|
||||
|
||||
const contents = []
|
||||
let systemInstruction = ''
|
||||
|
||||
// 辅助函数:提取文本内容
|
||||
function extractTextContent(content) {
|
||||
// 处理 null 或 undefined
|
||||
if (content == null) {
|
||||
return '';
|
||||
if (content === null || content === undefined) {
|
||||
return ''
|
||||
}
|
||||
|
||||
|
||||
// 处理字符串
|
||||
if (typeof content === 'string') {
|
||||
return content;
|
||||
return content
|
||||
}
|
||||
|
||||
|
||||
// 处理数组格式的内容
|
||||
if (Array.isArray(content)) {
|
||||
return content.map(item => {
|
||||
if (item == null) return '';
|
||||
if (typeof item === 'string') {
|
||||
return item;
|
||||
}
|
||||
if (typeof item === 'object') {
|
||||
// 处理 {type: 'text', text: '...'} 格式
|
||||
if (item.type === 'text' && item.text) {
|
||||
return item.text;
|
||||
return content
|
||||
.map((item) => {
|
||||
if (item === null || item === undefined) {
|
||||
return ''
|
||||
}
|
||||
// 处理 {text: '...'} 格式
|
||||
if (item.text) {
|
||||
return item.text;
|
||||
if (typeof item === 'string') {
|
||||
return item
|
||||
}
|
||||
// 处理嵌套的对象或数组
|
||||
if (item.content) {
|
||||
return extractTextContent(item.content);
|
||||
if (typeof item === 'object') {
|
||||
// 处理 {type: 'text', text: '...'} 格式
|
||||
if (item.type === 'text' && item.text) {
|
||||
return item.text
|
||||
}
|
||||
// 处理 {text: '...'} 格式
|
||||
if (item.text) {
|
||||
return item.text
|
||||
}
|
||||
// 处理嵌套的对象或数组
|
||||
if (item.content) {
|
||||
return extractTextContent(item.content)
|
||||
}
|
||||
}
|
||||
}
|
||||
return '';
|
||||
}).join('');
|
||||
return ''
|
||||
})
|
||||
.join('')
|
||||
}
|
||||
|
||||
|
||||
// 处理对象格式的内容
|
||||
if (typeof content === 'object') {
|
||||
// 处理 {text: '...'} 格式
|
||||
if (content.text) {
|
||||
return content.text;
|
||||
return content.text
|
||||
}
|
||||
// 处理 {content: '...'} 格式
|
||||
if (content.content) {
|
||||
return extractTextContent(content.content);
|
||||
return extractTextContent(content.content)
|
||||
}
|
||||
// 处理 {parts: [{text: '...'}]} 格式
|
||||
if (content.parts && Array.isArray(content.parts)) {
|
||||
return content.parts.map(part => {
|
||||
if (part && part.text) {
|
||||
return part.text;
|
||||
}
|
||||
return '';
|
||||
}).join('');
|
||||
return content.parts
|
||||
.map((part) => {
|
||||
if (part && part.text) {
|
||||
return part.text
|
||||
}
|
||||
return ''
|
||||
})
|
||||
.join('')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 最后的后备选项:只有在内容确实不为空且有意义时才转换为字符串
|
||||
if (content !== undefined && content !== null && content !== '' && typeof content !== 'object') {
|
||||
return String(content);
|
||||
if (
|
||||
content !== undefined &&
|
||||
content !== null &&
|
||||
content !== '' &&
|
||||
typeof content !== 'object'
|
||||
) {
|
||||
return String(content)
|
||||
}
|
||||
|
||||
return '';
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
|
||||
for (const message of messages) {
|
||||
const textContent = extractTextContent(message.content);
|
||||
|
||||
const textContent = extractTextContent(message.content)
|
||||
|
||||
if (message.role === 'system') {
|
||||
systemInstruction += (systemInstruction ? '\n\n' : '') + textContent;
|
||||
systemInstruction += (systemInstruction ? '\n\n' : '') + textContent
|
||||
} else if (message.role === 'user') {
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [{ text: textContent }]
|
||||
});
|
||||
})
|
||||
} else if (message.role === 'assistant') {
|
||||
contents.push({
|
||||
role: 'model',
|
||||
parts: [{ text: textContent }]
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return { contents, systemInstruction };
|
||||
|
||||
return { contents, systemInstruction }
|
||||
}
|
||||
|
||||
// 转换 Gemini 响应到 OpenAI 格式
|
||||
function convertGeminiResponseToOpenAI(geminiResponse, model, stream = false) {
|
||||
if (stream) {
|
||||
// 处理流式响应 - 原样返回 SSE 数据
|
||||
return geminiResponse;
|
||||
return geminiResponse
|
||||
} else {
|
||||
// 非流式响应转换
|
||||
// 处理嵌套的 response 结构
|
||||
const actualResponse = geminiResponse.response || geminiResponse;
|
||||
|
||||
const actualResponse = geminiResponse.response || geminiResponse
|
||||
|
||||
if (actualResponse.candidates && actualResponse.candidates.length > 0) {
|
||||
const candidate = actualResponse.candidates[0];
|
||||
const content = candidate.content?.parts?.[0]?.text || '';
|
||||
const finishReason = candidate.finishReason?.toLowerCase() || 'stop';
|
||||
const candidate = actualResponse.candidates[0]
|
||||
const content = candidate.content?.parts?.[0]?.text || ''
|
||||
const finishReason = candidate.finishReason?.toLowerCase() || 'stop'
|
||||
|
||||
// 计算 token 使用量
|
||||
const usage = actualResponse.usageMetadata || {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: content
|
||||
},
|
||||
finish_reason: finishReason
|
||||
}],
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content
|
||||
},
|
||||
finish_reason: finishReason
|
||||
}
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: usage.promptTokenCount,
|
||||
completion_tokens: usage.candidatesTokenCount,
|
||||
total_tokens: usage.totalTokenCount
|
||||
}
|
||||
};
|
||||
}
|
||||
} else {
|
||||
throw new Error('No response from Gemini');
|
||||
throw new Error('No response from Gemini')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI 兼容的聊天完成端点
|
||||
router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
const startTime = Date.now();
|
||||
let abortController = null;
|
||||
let account = null; // Declare account outside try block for error handling
|
||||
let accountSelection = null; // Declare accountSelection for error handling
|
||||
let sessionHash = null; // Declare sessionHash for error handling
|
||||
|
||||
const startTime = Date.now()
|
||||
let abortController = null
|
||||
let account = null // Declare account outside try block for error handling
|
||||
let accountSelection = null // Declare accountSelection for error handling
|
||||
let sessionHash = null // Declare sessionHash for error handling
|
||||
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
return res.status(403).json({
|
||||
@@ -182,25 +197,25 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
// 处理请求体结构 - 支持多种格式
|
||||
let requestBody = req.body;
|
||||
|
||||
let requestBody = req.body
|
||||
|
||||
// 如果请求体被包装在 body 字段中,解包它
|
||||
if (req.body.body && typeof req.body.body === 'object') {
|
||||
requestBody = req.body.body;
|
||||
requestBody = req.body.body
|
||||
}
|
||||
|
||||
|
||||
// 从 URL 路径中提取模型信息(如果存在)
|
||||
let urlModel = null;
|
||||
const urlPath = req.body?.config?.url || req.originalUrl || req.url;
|
||||
const modelMatch = urlPath.match(/\/([^/]+):(?:stream)?[Gg]enerateContent/);
|
||||
let urlModel = null
|
||||
const urlPath = req.body?.config?.url || req.originalUrl || req.url
|
||||
const modelMatch = urlPath.match(/\/([^/]+):(?:stream)?[Gg]enerateContent/)
|
||||
if (modelMatch) {
|
||||
urlModel = modelMatch[1];
|
||||
logger.debug(`Extracted model from URL: ${urlModel}`);
|
||||
urlModel = modelMatch[1]
|
||||
logger.debug(`Extracted model from URL: ${urlModel}`)
|
||||
}
|
||||
|
||||
|
||||
// 提取请求参数
|
||||
const {
|
||||
messages: requestMessages,
|
||||
@@ -209,19 +224,19 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
temperature = 0.7,
|
||||
max_tokens = 4096,
|
||||
stream = false
|
||||
} = requestBody;
|
||||
|
||||
} = requestBody
|
||||
|
||||
// 检查URL中是否包含stream标识
|
||||
const isStreamFromUrl = urlPath && urlPath.includes('streamGenerateContent');
|
||||
const actualStream = stream || isStreamFromUrl;
|
||||
const isStreamFromUrl = urlPath && urlPath.includes('streamGenerateContent')
|
||||
const actualStream = stream || isStreamFromUrl
|
||||
|
||||
// 优先使用 URL 中的模型,其次是请求体中的模型
|
||||
const model = urlModel || bodyModel;
|
||||
const model = urlModel || bodyModel
|
||||
|
||||
// 支持两种格式: OpenAI 的 messages 或 Gemini 的 contents
|
||||
let messages = requestMessages;
|
||||
let messages = requestMessages
|
||||
if (requestContents && Array.isArray(requestContents)) {
|
||||
messages = requestContents;
|
||||
messages = requestContents
|
||||
}
|
||||
|
||||
// 验证必需参数
|
||||
@@ -232,9 +247,9 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'invalid_request'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(model)) {
|
||||
@@ -244,13 +259,13 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_allowed'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 转换消息格式
|
||||
const { contents: geminiContents, systemInstruction } = convertMessagesToGemini(messages);
|
||||
|
||||
const { contents: geminiContents, systemInstruction } = convertMessagesToGemini(messages)
|
||||
|
||||
// 构建 Gemini 请求体
|
||||
const geminiRequestBody = {
|
||||
contents: geminiContents,
|
||||
@@ -259,24 +274,28 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
maxOutputTokens: max_tokens,
|
||||
candidateCount: 1
|
||||
}
|
||||
};
|
||||
|
||||
if (systemInstruction) {
|
||||
geminiRequestBody.systemInstruction = { parts: [{ text: systemInstruction }] };
|
||||
}
|
||||
|
||||
|
||||
if (systemInstruction) {
|
||||
geminiRequestBody.systemInstruction = { parts: [{ text: systemInstruction }] }
|
||||
}
|
||||
|
||||
// 生成会话哈希用于粘性会话
|
||||
sessionHash = generateSessionHash(req);
|
||||
|
||||
sessionHash = generateSessionHash(req)
|
||||
|
||||
// 选择可用的 Gemini 账户
|
||||
try {
|
||||
accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(apiKeyData, sessionHash, model);
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId);
|
||||
accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
sessionHash,
|
||||
model
|
||||
)
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId)
|
||||
} catch (error) {
|
||||
logger.error('Failed to select Gemini account:', error);
|
||||
account = null;
|
||||
logger.error('Failed to select Gemini account:', error)
|
||||
account = null
|
||||
}
|
||||
|
||||
|
||||
if (!account) {
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
@@ -284,35 +303,38 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
type: 'service_unavailable',
|
||||
code: 'service_unavailable'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`);
|
||||
|
||||
|
||||
logger.info(`Using Gemini account: ${account.id} for API key: ${apiKeyData.id}`)
|
||||
|
||||
// 标记账户被使用
|
||||
await geminiAccountService.markAccountUsed(account.id);
|
||||
|
||||
await geminiAccountService.markAccountUsed(account.id)
|
||||
|
||||
// 创建中止控制器
|
||||
abortController = new AbortController();
|
||||
|
||||
abortController = new AbortController()
|
||||
|
||||
// 处理客户端断开连接
|
||||
req.on('close', () => {
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
logger.info('Client disconnected, aborting Gemini request');
|
||||
abortController.abort();
|
||||
logger.info('Client disconnected, aborting Gemini request')
|
||||
abortController.abort()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 获取OAuth客户端
|
||||
const client = await geminiAccountService.getOauthClient(account.accessToken, account.refreshToken);
|
||||
const client = await geminiAccountService.getOauthClient(
|
||||
account.accessToken,
|
||||
account.refreshToken
|
||||
)
|
||||
if (actualStream) {
|
||||
// 流式响应
|
||||
logger.info('StreamGenerateContent request', {
|
||||
model: model,
|
||||
model,
|
||||
projectId: account.projectId,
|
||||
apiKeyId: apiKeyData.id
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
const streamResponse = await geminiAccountService.generateContentStream(
|
||||
client,
|
||||
{ model, request: geminiRequestBody },
|
||||
@@ -320,93 +342,101 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
account.projectId, // 使用有权限的项目ID
|
||||
apiKeyData.id, // 使用 API Key ID 作为 session ID
|
||||
abortController.signal // 传递中止信号
|
||||
);
|
||||
|
||||
)
|
||||
|
||||
// 设置流式响应头
|
||||
res.setHeader('Content-Type', 'text/event-stream');
|
||||
res.setHeader('Cache-Control', 'no-cache');
|
||||
res.setHeader('Connection', 'keep-alive');
|
||||
res.setHeader('X-Accel-Buffering', 'no');
|
||||
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
|
||||
// 处理流式响应,转换为 OpenAI 格式
|
||||
let buffer = '';
|
||||
|
||||
let buffer = ''
|
||||
|
||||
// 发送初始的空消息,符合 OpenAI 流式格式
|
||||
const initialChunk = {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion.chunk',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: { role: 'assistant' },
|
||||
finish_reason: null
|
||||
}]
|
||||
};
|
||||
res.write(`data: ${JSON.stringify(initialChunk)}\n\n`);
|
||||
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: { role: 'assistant' },
|
||||
finish_reason: null
|
||||
}
|
||||
]
|
||||
}
|
||||
res.write(`data: ${JSON.stringify(initialChunk)}\n\n`)
|
||||
|
||||
// 用于收集usage数据
|
||||
let totalUsage = {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
let usageReported = false;
|
||||
}
|
||||
const usageReported = false
|
||||
|
||||
streamResponse.on('data', (chunk) => {
|
||||
try {
|
||||
const chunkStr = chunk.toString();
|
||||
|
||||
const chunkStr = chunk.toString()
|
||||
|
||||
if (!chunkStr.trim()) {
|
||||
return;
|
||||
return
|
||||
}
|
||||
|
||||
buffer += chunkStr;
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || ''; // 保留最后一个不完整的行
|
||||
|
||||
|
||||
buffer += chunkStr
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || '' // 保留最后一个不完整的行
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.trim()) continue;
|
||||
|
||||
// 处理 SSE 格式
|
||||
let jsonData = line;
|
||||
if (line.startsWith('data: ')) {
|
||||
jsonData = line.substring(6).trim();
|
||||
if (!line.trim()) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (!jsonData || jsonData === '[DONE]') continue;
|
||||
|
||||
|
||||
// 处理 SSE 格式
|
||||
let jsonData = line
|
||||
if (line.startsWith('data: ')) {
|
||||
jsonData = line.substring(6).trim()
|
||||
}
|
||||
|
||||
if (!jsonData || jsonData === '[DONE]') {
|
||||
continue
|
||||
}
|
||||
|
||||
try {
|
||||
const data = JSON.parse(jsonData);
|
||||
|
||||
const data = JSON.parse(jsonData)
|
||||
|
||||
// 捕获usage数据
|
||||
if (data.response?.usageMetadata) {
|
||||
totalUsage = data.response.usageMetadata;
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage);
|
||||
totalUsage = data.response.usageMetadata
|
||||
logger.debug('📊 Captured Gemini usage data:', totalUsage)
|
||||
}
|
||||
|
||||
|
||||
// 转换为 OpenAI 流式格式
|
||||
if (data.response?.candidates && data.response.candidates.length > 0) {
|
||||
const candidate = data.response.candidates[0];
|
||||
const content = candidate.content?.parts?.[0]?.text || '';
|
||||
const finishReason = candidate.finishReason;
|
||||
|
||||
const candidate = data.response.candidates[0]
|
||||
const content = candidate.content?.parts?.[0]?.text || ''
|
||||
const { finishReason } = candidate
|
||||
|
||||
// 只有当有内容或者是结束标记时才发送数据
|
||||
if (content || finishReason === 'STOP') {
|
||||
const openaiChunk = {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion.chunk',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: content ? { content: content } : {},
|
||||
finish_reason: finishReason === 'STOP' ? 'stop' : null
|
||||
}]
|
||||
};
|
||||
|
||||
res.write(`data: ${JSON.stringify(openaiChunk)}\n\n`);
|
||||
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: content ? { content } : {},
|
||||
finish_reason: finishReason === 'STOP' ? 'stop' : null
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
res.write(`data: ${JSON.stringify(openaiChunk)}\n\n`)
|
||||
|
||||
// 如果结束了,添加 usage 信息并发送最终的 [DONE]
|
||||
if (finishReason === 'STOP') {
|
||||
// 如果有 usage 数据,添加到最后一个 chunk
|
||||
@@ -415,48 +445,50 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion.chunk',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: 'stop'
|
||||
}],
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: 'stop'
|
||||
}
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: data.response.usageMetadata.promptTokenCount || 0,
|
||||
completion_tokens: data.response.usageMetadata.candidatesTokenCount || 0,
|
||||
total_tokens: data.response.usageMetadata.totalTokenCount || 0
|
||||
}
|
||||
};
|
||||
res.write(`data: ${JSON.stringify(usageChunk)}\n\n`);
|
||||
}
|
||||
res.write(`data: ${JSON.stringify(usageChunk)}\n\n`)
|
||||
}
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.write('data: [DONE]\n\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug('Error parsing JSON line:', e.message);
|
||||
logger.debug('Error parsing JSON line:', e.message)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Stream processing error:', error);
|
||||
logger.error('Stream processing error:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Stream error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
streamResponse.on('end', async () => {
|
||||
logger.info('Stream completed successfully');
|
||||
|
||||
logger.info('Stream completed successfully')
|
||||
|
||||
// 记录使用统计
|
||||
if (!usageReported && totalUsage.totalTokenCount > 0) {
|
||||
try {
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
await apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
totalUsage.promptTokenCount || 0,
|
||||
@@ -465,59 +497,60 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini stream usage - Input: ${totalUsage.promptTokenCount}, Output: ${totalUsage.candidatesTokenCount}, Total: ${totalUsage.totalTokenCount}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.write('data: [DONE]\n\n')
|
||||
}
|
||||
res.end();
|
||||
});
|
||||
|
||||
res.end()
|
||||
})
|
||||
|
||||
streamResponse.on('error', (error) => {
|
||||
logger.error('Stream error:', error);
|
||||
logger.error('Stream error:', error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: error.message || 'Stream error',
|
||||
type: 'api_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
} else {
|
||||
// 如果已经开始发送流数据,发送错误事件
|
||||
res.write(`data: {"error": {"message": "${error.message || 'Stream error'}"}}\n\n`);
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.end();
|
||||
res.write(`data: {"error": {"message": "${error.message || 'Stream error'}"}}\n\n`)
|
||||
res.write('data: [DONE]\n\n')
|
||||
res.end()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} else {
|
||||
// 非流式响应
|
||||
logger.info('GenerateContent request', {
|
||||
model: model,
|
||||
model,
|
||||
projectId: account.projectId,
|
||||
apiKeyId: apiKeyData.id
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
const response = await geminiAccountService.generateContent(
|
||||
client,
|
||||
{ model, request: geminiRequestBody },
|
||||
null, // user_prompt_id
|
||||
account.projectId, // 使用有权限的项目ID
|
||||
apiKeyData.id // 使用 API Key ID 作为 session ID
|
||||
);
|
||||
|
||||
)
|
||||
|
||||
// 转换为 OpenAI 格式并返回
|
||||
const openaiResponse = convertGeminiResponseToOpenAI(response, model, false);
|
||||
|
||||
const openaiResponse = convertGeminiResponseToOpenAI(response, model, false)
|
||||
|
||||
// 记录使用统计
|
||||
if (openaiResponse.usage) {
|
||||
try {
|
||||
const apiKeyService = require('../services/apiKeyService');
|
||||
const apiKeyService = require('../services/apiKeyService')
|
||||
await apiKeyService.recordUsage(
|
||||
apiKeyData.id,
|
||||
openaiResponse.usage.prompt_tokens || 0,
|
||||
@@ -526,53 +559,55 @@ router.post('/v1/chat/completions', authenticateApiKey, async (req, res) => {
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
account.id
|
||||
);
|
||||
logger.info(`📊 Recorded Gemini usage - Input: ${openaiResponse.usage.prompt_tokens}, Output: ${openaiResponse.usage.completion_tokens}, Total: ${openaiResponse.usage.total_tokens}`);
|
||||
)
|
||||
logger.info(
|
||||
`📊 Recorded Gemini usage - Input: ${openaiResponse.usage.prompt_tokens}, Output: ${openaiResponse.usage.completion_tokens}, Total: ${openaiResponse.usage.total_tokens}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to record Gemini usage:', error);
|
||||
logger.error('Failed to record Gemini usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
res.json(openaiResponse);
|
||||
|
||||
res.json(openaiResponse)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.info(`OpenAI-Gemini request completed in ${duration}ms`);
|
||||
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`OpenAI-Gemini request completed in ${duration}ms`)
|
||||
} catch (error) {
|
||||
logger.error('OpenAI-Gemini request error:', error);
|
||||
|
||||
logger.error('OpenAI-Gemini request error:', error)
|
||||
|
||||
// 处理速率限制
|
||||
if (error.status === 429) {
|
||||
if (req.apiKey && account && accountSelection) {
|
||||
await unifiedGeminiScheduler.markAccountRateLimited(account.id, 'gemini', sessionHash);
|
||||
await unifiedGeminiScheduler.markAccountRateLimited(account.id, 'gemini', sessionHash)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 返回 OpenAI 格式的错误响应
|
||||
const status = error.status || 500;
|
||||
const status = error.status || 500
|
||||
const errorResponse = {
|
||||
error: error.error || {
|
||||
message: error.message || 'Internal server error',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
};
|
||||
|
||||
res.status(status).json(errorResponse);
|
||||
}
|
||||
|
||||
res.status(status).json(errorResponse)
|
||||
} finally {
|
||||
// 清理资源
|
||||
if (abortController) {
|
||||
abortController = null;
|
||||
abortController = null
|
||||
}
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// OpenAI 兼容的模型列表端点
|
||||
router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
return res.status(403).json({
|
||||
@@ -581,23 +616,27 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 选择账户获取模型列表
|
||||
let account = null;
|
||||
let account = null
|
||||
try {
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(apiKeyData, null, null);
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId);
|
||||
const accountSelection = await unifiedGeminiScheduler.selectAccountForApiKey(
|
||||
apiKeyData,
|
||||
null,
|
||||
null
|
||||
)
|
||||
account = await geminiAccountService.getAccount(accountSelection.accountId)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error);
|
||||
logger.warn('Failed to select Gemini account for models endpoint:', error)
|
||||
}
|
||||
|
||||
let models = [];
|
||||
|
||||
|
||||
let models = []
|
||||
|
||||
if (account) {
|
||||
// 获取实际的模型列表
|
||||
models = await getAvailableModels(account.accessToken, account.proxy);
|
||||
models = await getAvailableModels(account.accessToken, account.proxy)
|
||||
} else {
|
||||
// 返回默认模型列表
|
||||
models = [
|
||||
@@ -607,37 +646,37 @@ router.get('/v1/models', authenticateApiKey, async (req, res) => {
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
owned_by: 'google'
|
||||
}
|
||||
];
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
// 如果启用了模型限制,过滤模型列表
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels.length > 0) {
|
||||
models = models.filter(model => apiKeyData.restrictedModels.includes(model.id));
|
||||
models = models.filter((model) => apiKeyData.restrictedModels.includes(model.id))
|
||||
}
|
||||
|
||||
|
||||
res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get OpenAI-Gemini models:', error);
|
||||
logger.error('Failed to get OpenAI-Gemini models:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
// OpenAI 兼容的模型详情端点
|
||||
router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
try {
|
||||
const apiKeyData = req.apiKey;
|
||||
const modelId = req.params.model;
|
||||
|
||||
const apiKeyData = req.apiKey
|
||||
const modelId = req.params.model
|
||||
|
||||
// 检查权限
|
||||
if (!checkPermissions(apiKeyData, 'gemini')) {
|
||||
return res.status(403).json({
|
||||
@@ -646,9 +685,9 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'permission_denied',
|
||||
code: 'permission_denied'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels.length > 0) {
|
||||
if (!apiKeyData.restrictedModels.includes(modelId)) {
|
||||
@@ -658,10 +697,10 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
type: 'invalid_request_error',
|
||||
code: 'model_not_found'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 返回模型信息
|
||||
res.json({
|
||||
id: modelId,
|
||||
@@ -671,18 +710,18 @@ router.get('/v1/models/:model', authenticateApiKey, async (req, res) => {
|
||||
permission: [],
|
||||
root: modelId,
|
||||
parent: null
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to get model details:', error);
|
||||
logger.error('Failed to get model details:', error)
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve model details',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
return undefined
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,158 +1,157 @@
|
||||
const express = require('express');
|
||||
const bcrypt = require('bcryptjs');
|
||||
const crypto = require('crypto');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const config = require('../../config/config');
|
||||
const express = require('express')
|
||||
const bcrypt = require('bcryptjs')
|
||||
const crypto = require('crypto')
|
||||
const path = require('path')
|
||||
const fs = require('fs')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
|
||||
const router = express.Router();
|
||||
const router = express.Router()
|
||||
|
||||
// 🏠 服务静态文件
|
||||
router.use('/assets', express.static(path.join(__dirname, '../../web/assets')));
|
||||
router.use('/assets', express.static(path.join(__dirname, '../../web/assets')))
|
||||
|
||||
// 🌐 页面路由重定向到新版 admin-spa
|
||||
router.get('/', (req, res) => {
|
||||
res.redirect(301, '/admin-next/api-stats');
|
||||
});
|
||||
res.redirect(301, '/admin-next/api-stats')
|
||||
})
|
||||
|
||||
// 🔐 管理员登录
|
||||
router.post('/auth/login', async (req, res) => {
|
||||
try {
|
||||
const { username, password } = req.body;
|
||||
const { username, password } = req.body
|
||||
|
||||
if (!username || !password) {
|
||||
return res.status(400).json({
|
||||
error: 'Missing credentials',
|
||||
message: 'Username and password are required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 从Redis获取管理员信息
|
||||
let adminData = await redis.getSession('admin_credentials');
|
||||
|
||||
let adminData = await redis.getSession('admin_credentials')
|
||||
|
||||
// 如果Redis中没有管理员凭据,尝试从init.json重新加载
|
||||
if (!adminData || Object.keys(adminData).length === 0) {
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json');
|
||||
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json')
|
||||
|
||||
if (fs.existsSync(initFilePath)) {
|
||||
try {
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'));
|
||||
const saltRounds = 10;
|
||||
const passwordHash = await bcrypt.hash(initData.adminPassword, saltRounds);
|
||||
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'))
|
||||
const saltRounds = 10
|
||||
const passwordHash = await bcrypt.hash(initData.adminPassword, saltRounds)
|
||||
|
||||
adminData = {
|
||||
username: initData.adminUsername,
|
||||
passwordHash: passwordHash,
|
||||
passwordHash,
|
||||
createdAt: initData.initializedAt || new Date().toISOString(),
|
||||
lastLogin: null,
|
||||
updatedAt: initData.updatedAt || null
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 重新存储到Redis,不设置过期时间
|
||||
await redis.getClient().hset('session:admin_credentials', adminData);
|
||||
|
||||
logger.info('✅ Admin credentials reloaded from init.json');
|
||||
await redis.getClient().hset('session:admin_credentials', adminData)
|
||||
|
||||
logger.info('✅ Admin credentials reloaded from init.json')
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to reload admin credentials:', error);
|
||||
logger.error('❌ Failed to reload admin credentials:', error)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid credentials',
|
||||
message: 'Invalid username or password'
|
||||
});
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid credentials',
|
||||
message: 'Invalid username or password'
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 验证用户名和密码
|
||||
const isValidUsername = adminData.username === username;
|
||||
const isValidPassword = await bcrypt.compare(password, adminData.passwordHash);
|
||||
const isValidUsername = adminData.username === username
|
||||
const isValidPassword = await bcrypt.compare(password, adminData.passwordHash)
|
||||
|
||||
if (!isValidUsername || !isValidPassword) {
|
||||
logger.security(`🔒 Failed login attempt for username: ${username}`);
|
||||
logger.security(`🔒 Failed login attempt for username: ${username}`)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid credentials',
|
||||
message: 'Invalid username or password'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 生成会话token
|
||||
const sessionId = crypto.randomBytes(32).toString('hex');
|
||||
|
||||
const sessionId = crypto.randomBytes(32).toString('hex')
|
||||
|
||||
// 存储会话
|
||||
const sessionData = {
|
||||
username: adminData.username,
|
||||
loginTime: new Date().toISOString(),
|
||||
lastActivity: new Date().toISOString()
|
||||
};
|
||||
|
||||
await redis.setSession(sessionId, sessionData, config.security.adminSessionTimeout);
|
||||
|
||||
}
|
||||
|
||||
await redis.setSession(sessionId, sessionData, config.security.adminSessionTimeout)
|
||||
|
||||
// 不再更新 Redis 中的最后登录时间,因为 Redis 只是缓存
|
||||
// init.json 是唯一真实数据源
|
||||
|
||||
logger.success(`🔐 Admin login successful: ${username}`);
|
||||
logger.success(`🔐 Admin login successful: ${username}`)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
token: sessionId,
|
||||
expiresIn: config.security.adminSessionTimeout,
|
||||
username: adminData.username // 返回真实用户名
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Login error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Login error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Login failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🚪 管理员登出
|
||||
router.post('/auth/logout', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (token) {
|
||||
await redis.deleteSession(token);
|
||||
logger.success('🚪 Admin logout successful');
|
||||
await redis.deleteSession(token)
|
||||
logger.success('🚪 Admin logout successful')
|
||||
}
|
||||
|
||||
res.json({ success: true, message: 'Logout successful' });
|
||||
return res.json({ success: true, message: 'Logout successful' })
|
||||
} catch (error) {
|
||||
logger.error('❌ Logout error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Logout error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Logout failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🔑 修改账户信息
|
||||
router.post('/auth/change-password', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (!token) {
|
||||
return res.status(401).json({
|
||||
error: 'No token provided',
|
||||
message: 'Authentication required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const { newUsername, currentPassword, newPassword } = req.body;
|
||||
const { newUsername, currentPassword, newPassword } = req.body
|
||||
|
||||
if (!currentPassword || !newPassword) {
|
||||
return res.status(400).json({
|
||||
error: 'Missing required fields',
|
||||
message: 'Current password and new password are required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证新密码长度
|
||||
@@ -160,189 +159,186 @@ router.post('/auth/change-password', async (req, res) => {
|
||||
return res.status(400).json({
|
||||
error: 'Password too short',
|
||||
message: 'New password must be at least 8 characters long'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取当前会话
|
||||
const sessionData = await redis.getSession(token);
|
||||
const sessionData = await redis.getSession(token)
|
||||
if (!sessionData) {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid token',
|
||||
message: 'Session expired or invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取当前管理员信息
|
||||
const adminData = await redis.getSession('admin_credentials');
|
||||
const adminData = await redis.getSession('admin_credentials')
|
||||
if (!adminData) {
|
||||
return res.status(500).json({
|
||||
error: 'Admin data not found',
|
||||
message: 'Administrator credentials not found'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 验证当前密码
|
||||
const isValidPassword = await bcrypt.compare(currentPassword, adminData.passwordHash);
|
||||
const isValidPassword = await bcrypt.compare(currentPassword, adminData.passwordHash)
|
||||
if (!isValidPassword) {
|
||||
logger.security(`🔒 Invalid current password attempt for user: ${sessionData.username}`);
|
||||
logger.security(`🔒 Invalid current password attempt for user: ${sessionData.username}`)
|
||||
return res.status(401).json({
|
||||
error: 'Invalid current password',
|
||||
message: 'Current password is incorrect'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 准备更新的数据
|
||||
const updatedUsername = newUsername && newUsername.trim() ? newUsername.trim() : adminData.username;
|
||||
|
||||
const updatedUsername =
|
||||
newUsername && newUsername.trim() ? newUsername.trim() : adminData.username
|
||||
|
||||
// 先更新 init.json(唯一真实数据源)
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json');
|
||||
const initFilePath = path.join(__dirname, '../../data/init.json')
|
||||
if (!fs.existsSync(initFilePath)) {
|
||||
return res.status(500).json({
|
||||
error: 'Configuration file not found',
|
||||
message: 'init.json file is missing'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'));
|
||||
const initData = JSON.parse(fs.readFileSync(initFilePath, 'utf8'))
|
||||
// const oldData = { ...initData }; // 备份旧数据
|
||||
|
||||
|
||||
// 更新 init.json
|
||||
initData.adminUsername = updatedUsername;
|
||||
initData.adminPassword = newPassword; // 保存明文密码到init.json
|
||||
initData.updatedAt = new Date().toISOString();
|
||||
|
||||
initData.adminUsername = updatedUsername
|
||||
initData.adminPassword = newPassword // 保存明文密码到init.json
|
||||
initData.updatedAt = new Date().toISOString()
|
||||
|
||||
// 先写入文件(如果失败则不会影响 Redis)
|
||||
fs.writeFileSync(initFilePath, JSON.stringify(initData, null, 2));
|
||||
|
||||
fs.writeFileSync(initFilePath, JSON.stringify(initData, null, 2))
|
||||
|
||||
// 文件写入成功后,更新 Redis 缓存
|
||||
const saltRounds = 10;
|
||||
const newPasswordHash = await bcrypt.hash(newPassword, saltRounds);
|
||||
|
||||
const saltRounds = 10
|
||||
const newPasswordHash = await bcrypt.hash(newPassword, saltRounds)
|
||||
|
||||
const updatedAdminData = {
|
||||
username: updatedUsername,
|
||||
passwordHash: newPasswordHash,
|
||||
createdAt: adminData.createdAt,
|
||||
lastLogin: adminData.lastLogin,
|
||||
updatedAt: new Date().toISOString()
|
||||
};
|
||||
|
||||
await redis.setSession('admin_credentials', updatedAdminData);
|
||||
|
||||
}
|
||||
|
||||
await redis.setSession('admin_credentials', updatedAdminData)
|
||||
} catch (fileError) {
|
||||
logger.error('❌ Failed to update init.json:', fileError);
|
||||
logger.error('❌ Failed to update init.json:', fileError)
|
||||
return res.status(500).json({
|
||||
error: 'Update failed',
|
||||
message: 'Failed to update configuration file'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 清除当前会话(强制用户重新登录)
|
||||
await redis.deleteSession(token);
|
||||
await redis.deleteSession(token)
|
||||
|
||||
logger.success(`🔐 Admin password changed successfully for user: ${updatedUsername}`);
|
||||
logger.success(`🔐 Admin password changed successfully for user: ${updatedUsername}`)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
message: 'Password changed successfully. Please login again.',
|
||||
newUsername: updatedUsername
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Change password error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Change password error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Change password failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 👤 获取当前用户信息
|
||||
router.get('/auth/user', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (!token) {
|
||||
return res.status(401).json({
|
||||
error: 'No token provided',
|
||||
message: 'Authentication required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取当前会话
|
||||
const sessionData = await redis.getSession(token);
|
||||
const sessionData = await redis.getSession(token)
|
||||
if (!sessionData) {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid token',
|
||||
message: 'Session expired or invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 获取管理员信息
|
||||
const adminData = await redis.getSession('admin_credentials');
|
||||
const adminData = await redis.getSession('admin_credentials')
|
||||
if (!adminData) {
|
||||
return res.status(500).json({
|
||||
error: 'Admin data not found',
|
||||
message: 'Administrator credentials not found'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
user: {
|
||||
username: adminData.username,
|
||||
loginTime: sessionData.loginTime,
|
||||
lastActivity: sessionData.lastActivity
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Get user info error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Get user info error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Get user info failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
// 🔄 刷新token
|
||||
router.post('/auth/refresh', async (req, res) => {
|
||||
try {
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken;
|
||||
|
||||
const token = req.headers['authorization']?.replace('Bearer ', '') || req.cookies?.adminToken
|
||||
|
||||
if (!token) {
|
||||
return res.status(401).json({
|
||||
error: 'No token provided',
|
||||
message: 'Authentication required'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
const sessionData = await redis.getSession(token);
|
||||
|
||||
const sessionData = await redis.getSession(token)
|
||||
|
||||
if (!sessionData) {
|
||||
return res.status(401).json({
|
||||
error: 'Invalid token',
|
||||
message: 'Session expired or invalid'
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// 更新最后活动时间
|
||||
sessionData.lastActivity = new Date().toISOString();
|
||||
await redis.setSession(token, sessionData, config.security.adminSessionTimeout);
|
||||
sessionData.lastActivity = new Date().toISOString()
|
||||
await redis.setSession(token, sessionData, config.security.adminSessionTimeout)
|
||||
|
||||
res.json({
|
||||
return res.json({
|
||||
success: true,
|
||||
token: token,
|
||||
token,
|
||||
expiresIn: config.security.adminSessionTimeout
|
||||
});
|
||||
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('❌ Token refresh error:', error);
|
||||
res.status(500).json({
|
||||
logger.error('❌ Token refresh error:', error)
|
||||
return res.status(500).json({
|
||||
error: 'Token refresh failed',
|
||||
message: 'Internal server error'
|
||||
});
|
||||
})
|
||||
}
|
||||
});
|
||||
})
|
||||
|
||||
module.exports = router;
|
||||
module.exports = router
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const logger = require('../utils/logger');
|
||||
const redis = require('../models/redis');
|
||||
const { v4: uuidv4 } = require('uuid')
|
||||
const logger = require('../utils/logger')
|
||||
const redis = require('../models/redis')
|
||||
|
||||
class AccountGroupService {
|
||||
constructor() {
|
||||
this.GROUPS_KEY = 'account_groups';
|
||||
this.GROUP_PREFIX = 'account_group:';
|
||||
this.GROUP_MEMBERS_PREFIX = 'account_group_members:';
|
||||
this.GROUPS_KEY = 'account_groups'
|
||||
this.GROUP_PREFIX = 'account_group:'
|
||||
this.GROUP_MEMBERS_PREFIX = 'account_group_members:'
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -19,22 +19,22 @@ class AccountGroupService {
|
||||
*/
|
||||
async createGroup(groupData) {
|
||||
try {
|
||||
const { name, platform, description = '' } = groupData;
|
||||
|
||||
const { name, platform, description = '' } = groupData
|
||||
|
||||
// 验证必填字段
|
||||
if (!name || !platform) {
|
||||
throw new Error('分组名称和平台类型为必填项');
|
||||
throw new Error('分组名称和平台类型为必填项')
|
||||
}
|
||||
|
||||
|
||||
// 验证平台类型
|
||||
if (!['claude', 'gemini'].includes(platform)) {
|
||||
throw new Error('平台类型必须是 claude 或 gemini');
|
||||
throw new Error('平台类型必须是 claude 或 gemini')
|
||||
}
|
||||
|
||||
const client = redis.getClientSafe();
|
||||
const groupId = uuidv4();
|
||||
const now = new Date().toISOString();
|
||||
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const groupId = uuidv4()
|
||||
const now = new Date().toISOString()
|
||||
|
||||
const group = {
|
||||
id: groupId,
|
||||
name,
|
||||
@@ -42,20 +42,20 @@ class AccountGroupService {
|
||||
description,
|
||||
createdAt: now,
|
||||
updatedAt: now
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 保存分组数据
|
||||
await client.hmset(`${this.GROUP_PREFIX}${groupId}`, group);
|
||||
|
||||
await client.hmset(`${this.GROUP_PREFIX}${groupId}`, group)
|
||||
|
||||
// 添加到分组集合
|
||||
await client.sadd(this.GROUPS_KEY, groupId);
|
||||
|
||||
logger.success(`✅ 创建账户分组成功: ${name} (${platform})`);
|
||||
|
||||
return group;
|
||||
await client.sadd(this.GROUPS_KEY, groupId)
|
||||
|
||||
logger.success(`✅ 创建账户分组成功: ${name} (${platform})`)
|
||||
|
||||
return group
|
||||
} catch (error) {
|
||||
logger.error('❌ 创建账户分组失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 创建账户分组失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,46 +67,46 @@ class AccountGroupService {
|
||||
*/
|
||||
async updateGroup(groupId, updates) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const groupKey = `${this.GROUP_PREFIX}${groupId}`;
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const groupKey = `${this.GROUP_PREFIX}${groupId}`
|
||||
|
||||
// 检查分组是否存在
|
||||
const exists = await client.exists(groupKey);
|
||||
const exists = await client.exists(groupKey)
|
||||
if (!exists) {
|
||||
throw new Error('分组不存在');
|
||||
throw new Error('分组不存在')
|
||||
}
|
||||
|
||||
|
||||
// 获取现有分组数据
|
||||
const existingGroup = await client.hgetall(groupKey);
|
||||
|
||||
const existingGroup = await client.hgetall(groupKey)
|
||||
|
||||
// 不允许修改平台类型
|
||||
if (updates.platform && updates.platform !== existingGroup.platform) {
|
||||
throw new Error('不能修改分组的平台类型');
|
||||
throw new Error('不能修改分组的平台类型')
|
||||
}
|
||||
|
||||
|
||||
// 准备更新数据
|
||||
const updateData = {
|
||||
...updates,
|
||||
updatedAt: new Date().toISOString()
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 移除不允许修改的字段
|
||||
delete updateData.id;
|
||||
delete updateData.platform;
|
||||
delete updateData.createdAt;
|
||||
|
||||
delete updateData.id
|
||||
delete updateData.platform
|
||||
delete updateData.createdAt
|
||||
|
||||
// 更新分组
|
||||
await client.hmset(groupKey, updateData);
|
||||
|
||||
await client.hmset(groupKey, updateData)
|
||||
|
||||
// 返回更新后的完整数据
|
||||
const updatedGroup = await client.hgetall(groupKey);
|
||||
|
||||
logger.success(`✅ 更新账户分组成功: ${updatedGroup.name}`);
|
||||
|
||||
return updatedGroup;
|
||||
const updatedGroup = await client.hgetall(groupKey)
|
||||
|
||||
logger.success(`✅ 更新账户分组成功: ${updatedGroup.name}`)
|
||||
|
||||
return updatedGroup
|
||||
} catch (error) {
|
||||
logger.error('❌ 更新账户分组失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 更新账户分组失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,37 +116,37 @@ class AccountGroupService {
|
||||
*/
|
||||
async deleteGroup(groupId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 检查分组是否存在
|
||||
const group = await this.getGroup(groupId);
|
||||
const group = await this.getGroup(groupId)
|
||||
if (!group) {
|
||||
throw new Error('分组不存在');
|
||||
throw new Error('分组不存在')
|
||||
}
|
||||
|
||||
|
||||
// 检查分组是否为空
|
||||
const members = await this.getGroupMembers(groupId);
|
||||
const members = await this.getGroupMembers(groupId)
|
||||
if (members.length > 0) {
|
||||
throw new Error('分组内还有账户,无法删除');
|
||||
throw new Error('分组内还有账户,无法删除')
|
||||
}
|
||||
|
||||
|
||||
// 检查是否有API Key绑定此分组
|
||||
const boundApiKeys = await this.getApiKeysUsingGroup(groupId);
|
||||
const boundApiKeys = await this.getApiKeysUsingGroup(groupId)
|
||||
if (boundApiKeys.length > 0) {
|
||||
throw new Error('还有API Key使用此分组,无法删除');
|
||||
throw new Error('还有API Key使用此分组,无法删除')
|
||||
}
|
||||
|
||||
|
||||
// 删除分组数据
|
||||
await client.del(`${this.GROUP_PREFIX}${groupId}`);
|
||||
await client.del(`${this.GROUP_MEMBERS_PREFIX}${groupId}`);
|
||||
|
||||
await client.del(`${this.GROUP_PREFIX}${groupId}`)
|
||||
await client.del(`${this.GROUP_MEMBERS_PREFIX}${groupId}`)
|
||||
|
||||
// 从分组集合中移除
|
||||
await client.srem(this.GROUPS_KEY, groupId);
|
||||
|
||||
logger.success(`✅ 删除账户分组成功: ${group.name}`);
|
||||
await client.srem(this.GROUPS_KEY, groupId)
|
||||
|
||||
logger.success(`✅ 删除账户分组成功: ${group.name}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ 删除账户分组失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 删除账户分组失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,23 +157,23 @@ class AccountGroupService {
|
||||
*/
|
||||
async getGroup(groupId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const groupData = await client.hgetall(`${this.GROUP_PREFIX}${groupId}`);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const groupData = await client.hgetall(`${this.GROUP_PREFIX}${groupId}`)
|
||||
|
||||
if (!groupData || Object.keys(groupData).length === 0) {
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
|
||||
// 获取成员数量
|
||||
const memberCount = await client.scard(`${this.GROUP_MEMBERS_PREFIX}${groupId}`);
|
||||
|
||||
const memberCount = await client.scard(`${this.GROUP_MEMBERS_PREFIX}${groupId}`)
|
||||
|
||||
return {
|
||||
...groupData,
|
||||
memberCount: memberCount || 0
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取分组详情失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 获取分组详情失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,27 +184,27 @@ class AccountGroupService {
|
||||
*/
|
||||
async getAllGroups(platform = null) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const groupIds = await client.smembers(this.GROUPS_KEY);
|
||||
|
||||
const groups = [];
|
||||
const client = redis.getClientSafe()
|
||||
const groupIds = await client.smembers(this.GROUPS_KEY)
|
||||
|
||||
const groups = []
|
||||
for (const groupId of groupIds) {
|
||||
const group = await this.getGroup(groupId);
|
||||
const group = await this.getGroup(groupId)
|
||||
if (group) {
|
||||
// 如果指定了平台,进行筛选
|
||||
if (!platform || group.platform === platform) {
|
||||
groups.push(group);
|
||||
groups.push(group)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 按创建时间倒序排序
|
||||
groups.sort((a, b) => new Date(b.createdAt) - new Date(a.createdAt));
|
||||
|
||||
return groups;
|
||||
groups.sort((a, b) => new Date(b.createdAt) - new Date(a.createdAt))
|
||||
|
||||
return groups
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取分组列表失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 获取分组列表失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,27 +216,28 @@ class AccountGroupService {
|
||||
*/
|
||||
async addAccountToGroup(accountId, groupId, accountPlatform) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 获取分组信息
|
||||
const group = await this.getGroup(groupId);
|
||||
const group = await this.getGroup(groupId)
|
||||
if (!group) {
|
||||
throw new Error('分组不存在');
|
||||
throw new Error('分组不存在')
|
||||
}
|
||||
|
||||
|
||||
// 验证平台一致性 (Claude和Claude Console视为同一平台)
|
||||
const normalizedAccountPlatform = accountPlatform === 'claude-console' ? 'claude' : accountPlatform;
|
||||
const normalizedAccountPlatform =
|
||||
accountPlatform === 'claude-console' ? 'claude' : accountPlatform
|
||||
if (normalizedAccountPlatform !== group.platform) {
|
||||
throw new Error('账户平台与分组平台不匹配');
|
||||
throw new Error('账户平台与分组平台不匹配')
|
||||
}
|
||||
|
||||
|
||||
// 添加到分组成员集合
|
||||
await client.sadd(`${this.GROUP_MEMBERS_PREFIX}${groupId}`, accountId);
|
||||
|
||||
logger.success(`✅ 添加账户到分组成功: ${accountId} -> ${group.name}`);
|
||||
await client.sadd(`${this.GROUP_MEMBERS_PREFIX}${groupId}`, accountId)
|
||||
|
||||
logger.success(`✅ 添加账户到分组成功: ${accountId} -> ${group.name}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ 添加账户到分组失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 添加账户到分组失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,15 +248,15 @@ class AccountGroupService {
|
||||
*/
|
||||
async removeAccountFromGroup(accountId, groupId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 从分组成员集合中移除
|
||||
await client.srem(`${this.GROUP_MEMBERS_PREFIX}${groupId}`, accountId);
|
||||
|
||||
logger.success(`✅ 从分组移除账户成功: ${accountId}`);
|
||||
await client.srem(`${this.GROUP_MEMBERS_PREFIX}${groupId}`, accountId)
|
||||
|
||||
logger.success(`✅ 从分组移除账户成功: ${accountId}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ 从分组移除账户失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 从分组移除账户失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -266,12 +267,12 @@ class AccountGroupService {
|
||||
*/
|
||||
async getGroupMembers(groupId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const members = await client.smembers(`${this.GROUP_MEMBERS_PREFIX}${groupId}`);
|
||||
return members || [];
|
||||
const client = redis.getClientSafe()
|
||||
const members = await client.smembers(`${this.GROUP_MEMBERS_PREFIX}${groupId}`)
|
||||
return members || []
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取分组成员失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 获取分组成员失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,11 +283,11 @@ class AccountGroupService {
|
||||
*/
|
||||
async isGroupEmpty(groupId) {
|
||||
try {
|
||||
const members = await this.getGroupMembers(groupId);
|
||||
return members.length === 0;
|
||||
const members = await this.getGroupMembers(groupId)
|
||||
return members.length === 0
|
||||
} catch (error) {
|
||||
logger.error('❌ 检查分组是否为空失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 检查分组是否为空失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -297,29 +298,30 @@ class AccountGroupService {
|
||||
*/
|
||||
async getApiKeysUsingGroup(groupId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const groupKey = `group:${groupId}`;
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const groupKey = `group:${groupId}`
|
||||
|
||||
// 获取所有API Key
|
||||
const apiKeyIds = await client.smembers('api_keys');
|
||||
const boundApiKeys = [];
|
||||
|
||||
const apiKeyIds = await client.smembers('api_keys')
|
||||
const boundApiKeys = []
|
||||
|
||||
for (const keyId of apiKeyIds) {
|
||||
const keyData = await client.hgetall(`api_key:${keyId}`);
|
||||
if (keyData &&
|
||||
(keyData.claudeAccountId === groupKey ||
|
||||
keyData.geminiAccountId === groupKey)) {
|
||||
const keyData = await client.hgetall(`api_key:${keyId}`)
|
||||
if (
|
||||
keyData &&
|
||||
(keyData.claudeAccountId === groupKey || keyData.geminiAccountId === groupKey)
|
||||
) {
|
||||
boundApiKeys.push({
|
||||
id: keyId,
|
||||
name: keyData.name
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return boundApiKeys;
|
||||
|
||||
return boundApiKeys
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取使用分组的API Key失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 获取使用分组的API Key失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -330,22 +332,22 @@ class AccountGroupService {
|
||||
*/
|
||||
async getAccountGroup(accountId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const allGroupIds = await client.smembers(this.GROUPS_KEY);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const allGroupIds = await client.smembers(this.GROUPS_KEY)
|
||||
|
||||
for (const groupId of allGroupIds) {
|
||||
const isMember = await client.sismember(`${this.GROUP_MEMBERS_PREFIX}${groupId}`, accountId);
|
||||
const isMember = await client.sismember(`${this.GROUP_MEMBERS_PREFIX}${groupId}`, accountId)
|
||||
if (isMember) {
|
||||
return await this.getGroup(groupId);
|
||||
return await this.getGroup(groupId)
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
|
||||
return null
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取账户所属分组失败:', error);
|
||||
throw error;
|
||||
logger.error('❌ 获取账户所属分组失败:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new AccountGroupService();
|
||||
module.exports = new AccountGroupService()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const crypto = require('crypto');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const config = require('../../config/config');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const crypto = require('crypto')
|
||||
const { v4: uuidv4 } = require('uuid')
|
||||
const config = require('../../config/config')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
|
||||
class ApiKeyService {
|
||||
constructor() {
|
||||
this.prefix = config.security.apiKeyPrefix;
|
||||
this.prefix = config.security.apiKeyPrefix
|
||||
}
|
||||
|
||||
// 🔑 生成新的API Key
|
||||
@@ -30,13 +30,13 @@ class ApiKeyService {
|
||||
allowedClients = [],
|
||||
dailyCostLimit = 0,
|
||||
tags = []
|
||||
} = options;
|
||||
} = options
|
||||
|
||||
// 生成简单的API Key (64字符十六进制)
|
||||
const apiKey = `${this.prefix}${this._generateSecretKey()}`;
|
||||
const keyId = uuidv4();
|
||||
const hashedKey = this._hashApiKey(apiKey);
|
||||
|
||||
const apiKey = `${this.prefix}${this._generateSecretKey()}`
|
||||
const keyId = uuidv4()
|
||||
const hashedKey = this._hashApiKey(apiKey)
|
||||
|
||||
const keyData = {
|
||||
id: keyId,
|
||||
name,
|
||||
@@ -61,13 +61,13 @@ class ApiKeyService {
|
||||
lastUsedAt: '',
|
||||
expiresAt: expiresAt || '',
|
||||
createdBy: 'admin' // 可以根据需要扩展用户系统
|
||||
};
|
||||
}
|
||||
|
||||
// 保存API Key数据并建立哈希映射
|
||||
await redis.setApiKey(keyId, keyData, hashedKey);
|
||||
|
||||
logger.success(`🔑 Generated new API key: ${name} (${keyId})`);
|
||||
|
||||
await redis.setApiKey(keyId, keyData, hashedKey)
|
||||
|
||||
logger.success(`🔑 Generated new API key: ${name} (${keyId})`)
|
||||
|
||||
return {
|
||||
id: keyId,
|
||||
apiKey, // 只在创建时返回完整的key
|
||||
@@ -91,69 +91,69 @@ class ApiKeyService {
|
||||
createdAt: keyData.createdAt,
|
||||
expiresAt: keyData.expiresAt,
|
||||
createdBy: keyData.createdBy
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 验证API Key
|
||||
// 🔍 验证API Key
|
||||
async validateApiKey(apiKey) {
|
||||
try {
|
||||
if (!apiKey || !apiKey.startsWith(this.prefix)) {
|
||||
return { valid: false, error: 'Invalid API key format' };
|
||||
return { valid: false, error: 'Invalid API key format' }
|
||||
}
|
||||
|
||||
// 计算API Key的哈希值
|
||||
const hashedKey = this._hashApiKey(apiKey);
|
||||
|
||||
const hashedKey = this._hashApiKey(apiKey)
|
||||
|
||||
// 通过哈希值直接查找API Key(性能优化)
|
||||
const keyData = await redis.findApiKeyByHash(hashedKey);
|
||||
|
||||
const keyData = await redis.findApiKeyByHash(hashedKey)
|
||||
|
||||
if (!keyData) {
|
||||
return { valid: false, error: 'API key not found' };
|
||||
return { valid: false, error: 'API key not found' }
|
||||
}
|
||||
|
||||
// 检查是否激活
|
||||
if (keyData.isActive !== 'true') {
|
||||
return { valid: false, error: 'API key is disabled' };
|
||||
return { valid: false, error: 'API key is disabled' }
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if (keyData.expiresAt && new Date() > new Date(keyData.expiresAt)) {
|
||||
return { valid: false, error: 'API key has expired' };
|
||||
return { valid: false, error: 'API key has expired' }
|
||||
}
|
||||
|
||||
// 获取使用统计(供返回数据使用)
|
||||
const usage = await redis.getUsageStats(keyData.id);
|
||||
|
||||
const usage = await redis.getUsageStats(keyData.id)
|
||||
|
||||
// 获取当日费用统计
|
||||
const dailyCost = await redis.getDailyCost(keyData.id);
|
||||
const dailyCost = await redis.getDailyCost(keyData.id)
|
||||
|
||||
// 更新最后使用时间(优化:只在实际API调用时更新,而不是验证时)
|
||||
// 注意:lastUsedAt的更新已移至recordUsage方法中
|
||||
|
||||
logger.api(`🔓 API key validated successfully: ${keyData.id}`);
|
||||
logger.api(`🔓 API key validated successfully: ${keyData.id}`)
|
||||
|
||||
// 解析限制模型数据
|
||||
let restrictedModels = [];
|
||||
let restrictedModels = []
|
||||
try {
|
||||
restrictedModels = keyData.restrictedModels ? JSON.parse(keyData.restrictedModels) : [];
|
||||
restrictedModels = keyData.restrictedModels ? JSON.parse(keyData.restrictedModels) : []
|
||||
} catch (e) {
|
||||
restrictedModels = [];
|
||||
restrictedModels = []
|
||||
}
|
||||
|
||||
// 解析允许的客户端
|
||||
let allowedClients = [];
|
||||
let allowedClients = []
|
||||
try {
|
||||
allowedClients = keyData.allowedClients ? JSON.parse(keyData.allowedClients) : [];
|
||||
allowedClients = keyData.allowedClients ? JSON.parse(keyData.allowedClients) : []
|
||||
} catch (e) {
|
||||
allowedClients = [];
|
||||
allowedClients = []
|
||||
}
|
||||
|
||||
// 解析标签
|
||||
let tags = [];
|
||||
let tags = []
|
||||
try {
|
||||
tags = keyData.tags ? JSON.parse(keyData.tags) : [];
|
||||
tags = keyData.tags ? JSON.parse(keyData.tags) : []
|
||||
} catch (e) {
|
||||
tags = [];
|
||||
tags = []
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -173,248 +173,306 @@ class ApiKeyService {
|
||||
rateLimitWindow: parseInt(keyData.rateLimitWindow || 0),
|
||||
rateLimitRequests: parseInt(keyData.rateLimitRequests || 0),
|
||||
enableModelRestriction: keyData.enableModelRestriction === 'true',
|
||||
restrictedModels: restrictedModels,
|
||||
restrictedModels,
|
||||
enableClientRestriction: keyData.enableClientRestriction === 'true',
|
||||
allowedClients: allowedClients,
|
||||
allowedClients,
|
||||
dailyCostLimit: parseFloat(keyData.dailyCostLimit || 0),
|
||||
dailyCost: dailyCost || 0,
|
||||
tags: tags,
|
||||
tags,
|
||||
usage
|
||||
}
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ API key validation error:', error);
|
||||
return { valid: false, error: 'Internal validation error' };
|
||||
logger.error('❌ API key validation error:', error)
|
||||
return { valid: false, error: 'Internal validation error' }
|
||||
}
|
||||
}
|
||||
|
||||
// 📋 获取所有API Keys
|
||||
async getAllApiKeys() {
|
||||
try {
|
||||
const apiKeys = await redis.getAllApiKeys();
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const apiKeys = await redis.getAllApiKeys()
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 为每个key添加使用统计和当前并发数
|
||||
for (const key of apiKeys) {
|
||||
key.usage = await redis.getUsageStats(key.id);
|
||||
key.tokenLimit = parseInt(key.tokenLimit);
|
||||
key.concurrencyLimit = parseInt(key.concurrencyLimit || 0);
|
||||
key.rateLimitWindow = parseInt(key.rateLimitWindow || 0);
|
||||
key.rateLimitRequests = parseInt(key.rateLimitRequests || 0);
|
||||
key.currentConcurrency = await redis.getConcurrency(key.id);
|
||||
key.isActive = key.isActive === 'true';
|
||||
key.enableModelRestriction = key.enableModelRestriction === 'true';
|
||||
key.enableClientRestriction = key.enableClientRestriction === 'true';
|
||||
key.permissions = key.permissions || 'all'; // 兼容旧数据
|
||||
key.dailyCostLimit = parseFloat(key.dailyCostLimit || 0);
|
||||
key.dailyCost = await redis.getDailyCost(key.id) || 0;
|
||||
|
||||
key.usage = await redis.getUsageStats(key.id)
|
||||
key.tokenLimit = parseInt(key.tokenLimit)
|
||||
key.concurrencyLimit = parseInt(key.concurrencyLimit || 0)
|
||||
key.rateLimitWindow = parseInt(key.rateLimitWindow || 0)
|
||||
key.rateLimitRequests = parseInt(key.rateLimitRequests || 0)
|
||||
key.currentConcurrency = await redis.getConcurrency(key.id)
|
||||
key.isActive = key.isActive === 'true'
|
||||
key.enableModelRestriction = key.enableModelRestriction === 'true'
|
||||
key.enableClientRestriction = key.enableClientRestriction === 'true'
|
||||
key.permissions = key.permissions || 'all' // 兼容旧数据
|
||||
key.dailyCostLimit = parseFloat(key.dailyCostLimit || 0)
|
||||
key.dailyCost = (await redis.getDailyCost(key.id)) || 0
|
||||
|
||||
// 获取当前时间窗口的请求次数和Token使用量
|
||||
if (key.rateLimitWindow > 0) {
|
||||
const requestCountKey = `rate_limit:requests:${key.id}`;
|
||||
const tokenCountKey = `rate_limit:tokens:${key.id}`;
|
||||
|
||||
key.currentWindowRequests = parseInt(await client.get(requestCountKey) || '0');
|
||||
key.currentWindowTokens = parseInt(await client.get(tokenCountKey) || '0');
|
||||
const requestCountKey = `rate_limit:requests:${key.id}`
|
||||
const tokenCountKey = `rate_limit:tokens:${key.id}`
|
||||
|
||||
key.currentWindowRequests = parseInt((await client.get(requestCountKey)) || '0')
|
||||
key.currentWindowTokens = parseInt((await client.get(tokenCountKey)) || '0')
|
||||
} else {
|
||||
key.currentWindowRequests = 0;
|
||||
key.currentWindowTokens = 0;
|
||||
key.currentWindowRequests = 0
|
||||
key.currentWindowTokens = 0
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
key.restrictedModels = key.restrictedModels ? JSON.parse(key.restrictedModels) : [];
|
||||
key.restrictedModels = key.restrictedModels ? JSON.parse(key.restrictedModels) : []
|
||||
} catch (e) {
|
||||
key.restrictedModels = [];
|
||||
key.restrictedModels = []
|
||||
}
|
||||
try {
|
||||
key.allowedClients = key.allowedClients ? JSON.parse(key.allowedClients) : [];
|
||||
key.allowedClients = key.allowedClients ? JSON.parse(key.allowedClients) : []
|
||||
} catch (e) {
|
||||
key.allowedClients = [];
|
||||
key.allowedClients = []
|
||||
}
|
||||
try {
|
||||
key.tags = key.tags ? JSON.parse(key.tags) : [];
|
||||
key.tags = key.tags ? JSON.parse(key.tags) : []
|
||||
} catch (e) {
|
||||
key.tags = [];
|
||||
key.tags = []
|
||||
}
|
||||
delete key.apiKey; // 不返回哈希后的key
|
||||
delete key.apiKey // 不返回哈希后的key
|
||||
}
|
||||
|
||||
return apiKeys;
|
||||
return apiKeys
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get API keys:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to get API keys:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📝 更新API Key
|
||||
async updateApiKey(keyId, updates) {
|
||||
try {
|
||||
const keyData = await redis.getApiKey(keyId);
|
||||
const keyData = await redis.getApiKey(keyId)
|
||||
if (!keyData || Object.keys(keyData).length === 0) {
|
||||
throw new Error('API key not found');
|
||||
throw new Error('API key not found')
|
||||
}
|
||||
|
||||
// 允许更新的字段
|
||||
const allowedUpdates = ['name', 'description', 'tokenLimit', 'concurrencyLimit', 'rateLimitWindow', 'rateLimitRequests', 'isActive', 'claudeAccountId', 'claudeConsoleAccountId', 'geminiAccountId', 'permissions', 'expiresAt', 'enableModelRestriction', 'restrictedModels', 'enableClientRestriction', 'allowedClients', 'dailyCostLimit', 'tags'];
|
||||
const updatedData = { ...keyData };
|
||||
const allowedUpdates = [
|
||||
'name',
|
||||
'description',
|
||||
'tokenLimit',
|
||||
'concurrencyLimit',
|
||||
'rateLimitWindow',
|
||||
'rateLimitRequests',
|
||||
'isActive',
|
||||
'claudeAccountId',
|
||||
'claudeConsoleAccountId',
|
||||
'geminiAccountId',
|
||||
'permissions',
|
||||
'expiresAt',
|
||||
'enableModelRestriction',
|
||||
'restrictedModels',
|
||||
'enableClientRestriction',
|
||||
'allowedClients',
|
||||
'dailyCostLimit',
|
||||
'tags'
|
||||
]
|
||||
const updatedData = { ...keyData }
|
||||
|
||||
for (const [field, value] of Object.entries(updates)) {
|
||||
if (allowedUpdates.includes(field)) {
|
||||
if (field === 'restrictedModels' || field === 'allowedClients' || field === 'tags') {
|
||||
// 特殊处理数组字段
|
||||
updatedData[field] = JSON.stringify(value || []);
|
||||
updatedData[field] = JSON.stringify(value || [])
|
||||
} else if (field === 'enableModelRestriction' || field === 'enableClientRestriction') {
|
||||
// 布尔值转字符串
|
||||
updatedData[field] = String(value);
|
||||
updatedData[field] = String(value)
|
||||
} else {
|
||||
updatedData[field] = (value != null ? value : '').toString();
|
||||
updatedData[field] = (value !== null && value !== undefined ? value : '').toString()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updatedData.updatedAt = new Date().toISOString();
|
||||
|
||||
updatedData.updatedAt = new Date().toISOString()
|
||||
|
||||
// 更新时不需要重新建立哈希映射,因为API Key本身没有变化
|
||||
await redis.setApiKey(keyId, updatedData);
|
||||
|
||||
logger.success(`📝 Updated API key: ${keyId}`);
|
||||
|
||||
return { success: true };
|
||||
await redis.setApiKey(keyId, updatedData)
|
||||
|
||||
logger.success(`📝 Updated API key: ${keyId}`)
|
||||
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to update API key:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to update API key:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🗑️ 删除API Key
|
||||
async deleteApiKey(keyId) {
|
||||
try {
|
||||
const result = await redis.deleteApiKey(keyId);
|
||||
|
||||
const result = await redis.deleteApiKey(keyId)
|
||||
|
||||
if (result === 0) {
|
||||
throw new Error('API key not found');
|
||||
throw new Error('API key not found')
|
||||
}
|
||||
|
||||
logger.success(`🗑️ Deleted API key: ${keyId}`);
|
||||
|
||||
return { success: true };
|
||||
|
||||
logger.success(`🗑️ Deleted API key: ${keyId}`)
|
||||
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to delete API key:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to delete API key:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 记录使用情况(支持缓存token和账户级别统计)
|
||||
async recordUsage(keyId, inputTokens = 0, outputTokens = 0, cacheCreateTokens = 0, cacheReadTokens = 0, model = 'unknown', accountId = null) {
|
||||
async recordUsage(
|
||||
keyId,
|
||||
inputTokens = 0,
|
||||
outputTokens = 0,
|
||||
cacheCreateTokens = 0,
|
||||
cacheReadTokens = 0,
|
||||
model = 'unknown',
|
||||
accountId = null
|
||||
) {
|
||||
try {
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens;
|
||||
|
||||
const totalTokens = inputTokens + outputTokens + cacheCreateTokens + cacheReadTokens
|
||||
|
||||
// 计算费用
|
||||
const CostCalculator = require('../utils/costCalculator');
|
||||
const costInfo = CostCalculator.calculateCost({
|
||||
input_tokens: inputTokens,
|
||||
output_tokens: outputTokens,
|
||||
cache_creation_input_tokens: cacheCreateTokens,
|
||||
cache_read_input_tokens: cacheReadTokens
|
||||
}, model);
|
||||
|
||||
const CostCalculator = require('../utils/costCalculator')
|
||||
const costInfo = CostCalculator.calculateCost(
|
||||
{
|
||||
input_tokens: inputTokens,
|
||||
output_tokens: outputTokens,
|
||||
cache_creation_input_tokens: cacheCreateTokens,
|
||||
cache_read_input_tokens: cacheReadTokens
|
||||
},
|
||||
model
|
||||
)
|
||||
|
||||
// 记录API Key级别的使用统计
|
||||
await redis.incrementTokenUsage(keyId, totalTokens, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model);
|
||||
|
||||
await redis.incrementTokenUsage(
|
||||
keyId,
|
||||
totalTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model
|
||||
)
|
||||
|
||||
// 记录费用统计
|
||||
if (costInfo.costs.total > 0) {
|
||||
await redis.incrementDailyCost(keyId, costInfo.costs.total);
|
||||
logger.database(`💰 Recorded cost for ${keyId}: $${costInfo.costs.total.toFixed(6)}, model: ${model}`);
|
||||
await redis.incrementDailyCost(keyId, costInfo.costs.total)
|
||||
logger.database(
|
||||
`💰 Recorded cost for ${keyId}: $${costInfo.costs.total.toFixed(6)}, model: ${model}`
|
||||
)
|
||||
} else {
|
||||
logger.debug(`💰 No cost recorded for ${keyId} - zero cost for model: ${model}`);
|
||||
logger.debug(`💰 No cost recorded for ${keyId} - zero cost for model: ${model}`)
|
||||
}
|
||||
|
||||
|
||||
// 获取API Key数据以确定关联的账户
|
||||
const keyData = await redis.getApiKey(keyId);
|
||||
const keyData = await redis.getApiKey(keyId)
|
||||
if (keyData && Object.keys(keyData).length > 0) {
|
||||
// 更新最后使用时间
|
||||
keyData.lastUsedAt = new Date().toISOString();
|
||||
await redis.setApiKey(keyId, keyData);
|
||||
|
||||
keyData.lastUsedAt = new Date().toISOString()
|
||||
await redis.setApiKey(keyId, keyData)
|
||||
|
||||
// 记录账户级别的使用统计(只统计实际处理请求的账户)
|
||||
if (accountId) {
|
||||
await redis.incrementAccountUsage(accountId, totalTokens, inputTokens, outputTokens, cacheCreateTokens, cacheReadTokens, model);
|
||||
logger.database(`📊 Recorded account usage: ${accountId} - ${totalTokens} tokens (API Key: ${keyId})`);
|
||||
await redis.incrementAccountUsage(
|
||||
accountId,
|
||||
totalTokens,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cacheCreateTokens,
|
||||
cacheReadTokens,
|
||||
model
|
||||
)
|
||||
logger.database(
|
||||
`📊 Recorded account usage: ${accountId} - ${totalTokens} tokens (API Key: ${keyId})`
|
||||
)
|
||||
} else {
|
||||
logger.debug('⚠️ No accountId provided for usage recording, skipping account-level statistics');
|
||||
logger.debug(
|
||||
'⚠️ No accountId provided for usage recording, skipping account-level statistics'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const logParts = [`Model: ${model}`, `Input: ${inputTokens}`, `Output: ${outputTokens}`];
|
||||
if (cacheCreateTokens > 0) logParts.push(`Cache Create: ${cacheCreateTokens}`);
|
||||
if (cacheReadTokens > 0) logParts.push(`Cache Read: ${cacheReadTokens}`);
|
||||
logParts.push(`Total: ${totalTokens} tokens`);
|
||||
|
||||
logger.database(`📊 Recorded usage: ${keyId} - ${logParts.join(', ')}`);
|
||||
|
||||
const logParts = [`Model: ${model}`, `Input: ${inputTokens}`, `Output: ${outputTokens}`]
|
||||
if (cacheCreateTokens > 0) {
|
||||
logParts.push(`Cache Create: ${cacheCreateTokens}`)
|
||||
}
|
||||
if (cacheReadTokens > 0) {
|
||||
logParts.push(`Cache Read: ${cacheReadTokens}`)
|
||||
}
|
||||
logParts.push(`Total: ${totalTokens} tokens`)
|
||||
|
||||
logger.database(`📊 Recorded usage: ${keyId} - ${logParts.join(', ')}`)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to record usage:', error);
|
||||
logger.error('❌ Failed to record usage:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 🔐 生成密钥
|
||||
_generateSecretKey() {
|
||||
return crypto.randomBytes(32).toString('hex');
|
||||
return crypto.randomBytes(32).toString('hex')
|
||||
}
|
||||
|
||||
// 🔒 哈希API Key
|
||||
_hashApiKey(apiKey) {
|
||||
return crypto.createHash('sha256').update(apiKey + config.security.encryptionKey).digest('hex');
|
||||
return crypto
|
||||
.createHash('sha256')
|
||||
.update(apiKey + config.security.encryptionKey)
|
||||
.digest('hex')
|
||||
}
|
||||
|
||||
// 📈 获取使用统计
|
||||
async getUsageStats(keyId) {
|
||||
return await redis.getUsageStats(keyId);
|
||||
return await redis.getUsageStats(keyId)
|
||||
}
|
||||
|
||||
// 📊 获取账户使用统计
|
||||
async getAccountUsageStats(accountId) {
|
||||
return await redis.getAccountUsageStats(accountId);
|
||||
return await redis.getAccountUsageStats(accountId)
|
||||
}
|
||||
|
||||
// 📈 获取所有账户使用统计
|
||||
async getAllAccountsUsageStats() {
|
||||
return await redis.getAllAccountsUsageStats();
|
||||
return await redis.getAllAccountsUsageStats()
|
||||
}
|
||||
|
||||
|
||||
// 🧹 清理过期的API Keys
|
||||
async cleanupExpiredKeys() {
|
||||
try {
|
||||
const apiKeys = await redis.getAllApiKeys();
|
||||
const now = new Date();
|
||||
let cleanedCount = 0;
|
||||
const apiKeys = await redis.getAllApiKeys()
|
||||
const now = new Date()
|
||||
let cleanedCount = 0
|
||||
|
||||
for (const key of apiKeys) {
|
||||
// 检查是否已过期且仍处于激活状态
|
||||
if (key.expiresAt && new Date(key.expiresAt) < now && key.isActive === 'true') {
|
||||
// 将过期的 API Key 标记为禁用状态,而不是直接删除
|
||||
await this.updateApiKey(key.id, { isActive: false });
|
||||
logger.info(`🔒 API Key ${key.id} (${key.name}) has expired and been disabled`);
|
||||
cleanedCount++;
|
||||
await this.updateApiKey(key.id, { isActive: false })
|
||||
logger.info(`🔒 API Key ${key.id} (${key.name}) has expired and been disabled`)
|
||||
cleanedCount++
|
||||
}
|
||||
}
|
||||
|
||||
if (cleanedCount > 0) {
|
||||
logger.success(`🧹 Disabled ${cleanedCount} expired API keys`);
|
||||
logger.success(`🧹 Disabled ${cleanedCount} expired API keys`)
|
||||
}
|
||||
|
||||
return cleanedCount;
|
||||
return cleanedCount
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to cleanup expired keys:', error);
|
||||
return 0;
|
||||
logger.error('❌ Failed to cleanup expired keys:', error)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 导出实例和单独的方法
|
||||
const apiKeyService = new ApiKeyService();
|
||||
const apiKeyService = new ApiKeyService()
|
||||
|
||||
// 为了方便其他服务调用,导出 recordUsage 方法
|
||||
apiKeyService.recordUsageMetrics = apiKeyService.recordUsage.bind(apiKeyService);
|
||||
apiKeyService.recordUsageMetrics = apiKeyService.recordUsage.bind(apiKeyService)
|
||||
|
||||
module.exports = apiKeyService;
|
||||
module.exports = apiKeyService
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const crypto = require('crypto');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const config = require('../../config/config');
|
||||
const bedrockRelayService = require('./bedrockRelayService');
|
||||
const { v4: uuidv4 } = require('uuid')
|
||||
const crypto = require('crypto')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
const bedrockRelayService = require('./bedrockRelayService')
|
||||
|
||||
class BedrockAccountService {
|
||||
constructor() {
|
||||
// 加密相关常量
|
||||
this.ENCRYPTION_ALGORITHM = 'aes-256-cbc';
|
||||
this.ENCRYPTION_SALT = 'salt';
|
||||
this.ENCRYPTION_ALGORITHM = 'aes-256-cbc'
|
||||
this.ENCRYPTION_SALT = 'salt'
|
||||
}
|
||||
|
||||
// 🏢 创建Bedrock账户
|
||||
@@ -25,11 +25,11 @@ class BedrockAccountService {
|
||||
priority = 50, // 调度优先级 (1-100,数字越小优先级越高)
|
||||
schedulable = true, // 是否可被调度
|
||||
credentialType = 'default' // 'default', 'access_key', 'bearer_token'
|
||||
} = options;
|
||||
} = options
|
||||
|
||||
const accountId = uuidv4();
|
||||
const accountId = uuidv4()
|
||||
|
||||
let accountData = {
|
||||
const accountData = {
|
||||
id: accountId,
|
||||
name,
|
||||
description,
|
||||
@@ -43,17 +43,17 @@ class BedrockAccountService {
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
type: 'bedrock' // 标识这是Bedrock账户
|
||||
};
|
||||
}
|
||||
|
||||
// 加密存储AWS凭证
|
||||
if (awsCredentials) {
|
||||
accountData.awsCredentials = this._encryptAwsCredentials(awsCredentials);
|
||||
accountData.awsCredentials = this._encryptAwsCredentials(awsCredentials)
|
||||
}
|
||||
|
||||
const client = redis.getClientSafe();
|
||||
await client.set(`bedrock_account:${accountId}`, JSON.stringify(accountData));
|
||||
const client = redis.getClientSafe()
|
||||
await client.set(`bedrock_account:${accountId}`, JSON.stringify(accountData))
|
||||
|
||||
logger.info(`✅ 创建Bedrock账户成功 - ID: ${accountId}, 名称: ${name}, 区域: ${region}`);
|
||||
logger.info(`✅ 创建Bedrock账户成功 - ID: ${accountId}, 名称: ${name}, 区域: ${region}`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
@@ -71,48 +71,48 @@ class BedrockAccountService {
|
||||
createdAt: accountData.createdAt,
|
||||
type: 'bedrock'
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 获取账户信息
|
||||
async getAccount(accountId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const accountData = await client.get(`bedrock_account:${accountId}`);
|
||||
const client = redis.getClientSafe()
|
||||
const accountData = await client.get(`bedrock_account:${accountId}`)
|
||||
if (!accountData) {
|
||||
return { success: false, error: 'Account not found' };
|
||||
return { success: false, error: 'Account not found' }
|
||||
}
|
||||
|
||||
const account = JSON.parse(accountData);
|
||||
const account = JSON.parse(accountData)
|
||||
|
||||
// 解密AWS凭证用于内部使用
|
||||
if (account.awsCredentials) {
|
||||
account.awsCredentials = this._decryptAwsCredentials(account.awsCredentials);
|
||||
account.awsCredentials = this._decryptAwsCredentials(account.awsCredentials)
|
||||
}
|
||||
|
||||
logger.debug(`🔍 获取Bedrock账户 - ID: ${accountId}, 名称: ${account.name}`);
|
||||
logger.debug(`🔍 获取Bedrock账户 - ID: ${accountId}, 名称: ${account.name}`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
data: account
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`❌ 获取Bedrock账户失败 - ID: ${accountId}`, error);
|
||||
return { success: false, error: error.message };
|
||||
logger.error(`❌ 获取Bedrock账户失败 - ID: ${accountId}`, error)
|
||||
return { success: false, error: error.message }
|
||||
}
|
||||
}
|
||||
|
||||
// 📋 获取所有账户列表
|
||||
async getAllAccounts() {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const keys = await client.keys('bedrock_account:*');
|
||||
const accounts = [];
|
||||
const client = redis.getClientSafe()
|
||||
const keys = await client.keys('bedrock_account:*')
|
||||
const accounts = []
|
||||
|
||||
for (const key of keys) {
|
||||
const accountData = await client.get(key);
|
||||
const accountData = await client.get(key)
|
||||
if (accountData) {
|
||||
const account = JSON.parse(accountData);
|
||||
const account = JSON.parse(accountData)
|
||||
|
||||
// 返回给前端时,不包含敏感信息,只显示掩码
|
||||
accounts.push({
|
||||
@@ -130,25 +130,27 @@ class BedrockAccountService {
|
||||
updatedAt: account.updatedAt,
|
||||
type: 'bedrock',
|
||||
hasCredentials: !!account.awsCredentials
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 按优先级和名称排序
|
||||
accounts.sort((a, b) => {
|
||||
if (a.priority !== b.priority) return a.priority - b.priority;
|
||||
return a.name.localeCompare(b.name);
|
||||
});
|
||||
if (a.priority !== b.priority) {
|
||||
return a.priority - b.priority
|
||||
}
|
||||
return a.name.localeCompare(b.name)
|
||||
})
|
||||
|
||||
logger.debug(`📋 获取所有Bedrock账户 - 共 ${accounts.length} 个`);
|
||||
logger.debug(`📋 获取所有Bedrock账户 - 共 ${accounts.length} 个`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
data: accounts
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取Bedrock账户列表失败', error);
|
||||
return { success: false, error: error.message };
|
||||
logger.error('❌ 获取Bedrock账户列表失败', error)
|
||||
return { success: false, error: error.message }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,44 +158,62 @@ class BedrockAccountService {
|
||||
async updateAccount(accountId, updates = {}) {
|
||||
try {
|
||||
// 获取原始账户数据(不解密凭证)
|
||||
const client = redis.getClientSafe();
|
||||
const accountData = await client.get(`bedrock_account:${accountId}`);
|
||||
const client = redis.getClientSafe()
|
||||
const accountData = await client.get(`bedrock_account:${accountId}`)
|
||||
if (!accountData) {
|
||||
return { success: false, error: 'Account not found' };
|
||||
return { success: false, error: 'Account not found' }
|
||||
}
|
||||
|
||||
const account = JSON.parse(accountData);
|
||||
const account = JSON.parse(accountData)
|
||||
|
||||
// 更新字段
|
||||
if (updates.name !== undefined) account.name = updates.name;
|
||||
if (updates.description !== undefined) account.description = updates.description;
|
||||
if (updates.region !== undefined) account.region = updates.region;
|
||||
if (updates.defaultModel !== undefined) account.defaultModel = updates.defaultModel;
|
||||
if (updates.isActive !== undefined) account.isActive = updates.isActive;
|
||||
if (updates.accountType !== undefined) account.accountType = updates.accountType;
|
||||
if (updates.priority !== undefined) account.priority = updates.priority;
|
||||
if (updates.schedulable !== undefined) account.schedulable = updates.schedulable;
|
||||
if (updates.credentialType !== undefined) account.credentialType = updates.credentialType;
|
||||
if (updates.name !== undefined) {
|
||||
account.name = updates.name
|
||||
}
|
||||
if (updates.description !== undefined) {
|
||||
account.description = updates.description
|
||||
}
|
||||
if (updates.region !== undefined) {
|
||||
account.region = updates.region
|
||||
}
|
||||
if (updates.defaultModel !== undefined) {
|
||||
account.defaultModel = updates.defaultModel
|
||||
}
|
||||
if (updates.isActive !== undefined) {
|
||||
account.isActive = updates.isActive
|
||||
}
|
||||
if (updates.accountType !== undefined) {
|
||||
account.accountType = updates.accountType
|
||||
}
|
||||
if (updates.priority !== undefined) {
|
||||
account.priority = updates.priority
|
||||
}
|
||||
if (updates.schedulable !== undefined) {
|
||||
account.schedulable = updates.schedulable
|
||||
}
|
||||
if (updates.credentialType !== undefined) {
|
||||
account.credentialType = updates.credentialType
|
||||
}
|
||||
|
||||
// 更新AWS凭证
|
||||
if (updates.awsCredentials !== undefined) {
|
||||
if (updates.awsCredentials) {
|
||||
account.awsCredentials = this._encryptAwsCredentials(updates.awsCredentials);
|
||||
account.awsCredentials = this._encryptAwsCredentials(updates.awsCredentials)
|
||||
} else {
|
||||
delete account.awsCredentials;
|
||||
delete account.awsCredentials
|
||||
}
|
||||
} else if (account.awsCredentials && account.awsCredentials.accessKeyId) {
|
||||
// 如果没有提供新凭证但现有凭证是明文格式,重新加密
|
||||
const plainCredentials = account.awsCredentials;
|
||||
account.awsCredentials = this._encryptAwsCredentials(plainCredentials);
|
||||
logger.info(`🔐 重新加密Bedrock账户凭证 - ID: ${accountId}`);
|
||||
const plainCredentials = account.awsCredentials
|
||||
account.awsCredentials = this._encryptAwsCredentials(plainCredentials)
|
||||
logger.info(`🔐 重新加密Bedrock账户凭证 - ID: ${accountId}`)
|
||||
}
|
||||
|
||||
account.updatedAt = new Date().toISOString();
|
||||
account.updatedAt = new Date().toISOString()
|
||||
|
||||
await client.set(`bedrock_account:${accountId}`, JSON.stringify(account));
|
||||
await client.set(`bedrock_account:${accountId}`, JSON.stringify(account))
|
||||
|
||||
logger.info(`✅ 更新Bedrock账户成功 - ID: ${accountId}, 名称: ${account.name}`);
|
||||
logger.info(`✅ 更新Bedrock账户成功 - ID: ${accountId}, 名称: ${account.name}`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
@@ -211,87 +231,87 @@ class BedrockAccountService {
|
||||
updatedAt: account.updatedAt,
|
||||
type: 'bedrock'
|
||||
}
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`❌ 更新Bedrock账户失败 - ID: ${accountId}`, error);
|
||||
return { success: false, error: error.message };
|
||||
logger.error(`❌ 更新Bedrock账户失败 - ID: ${accountId}`, error)
|
||||
return { success: false, error: error.message }
|
||||
}
|
||||
}
|
||||
|
||||
// 🗑️ 删除账户
|
||||
async deleteAccount(accountId) {
|
||||
try {
|
||||
const accountResult = await this.getAccount(accountId);
|
||||
const accountResult = await this.getAccount(accountId)
|
||||
if (!accountResult.success) {
|
||||
return accountResult;
|
||||
return accountResult
|
||||
}
|
||||
|
||||
const client = redis.getClientSafe();
|
||||
await client.del(`bedrock_account:${accountId}`);
|
||||
const client = redis.getClientSafe()
|
||||
await client.del(`bedrock_account:${accountId}`)
|
||||
|
||||
logger.info(`✅ 删除Bedrock账户成功 - ID: ${accountId}`);
|
||||
logger.info(`✅ 删除Bedrock账户成功 - ID: ${accountId}`)
|
||||
|
||||
return { success: true };
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ 删除Bedrock账户失败 - ID: ${accountId}`, error);
|
||||
return { success: false, error: error.message };
|
||||
logger.error(`❌ 删除Bedrock账户失败 - ID: ${accountId}`, error)
|
||||
return { success: false, error: error.message }
|
||||
}
|
||||
}
|
||||
|
||||
// 🎯 选择可用的Bedrock账户 (用于请求转发)
|
||||
async selectAvailableAccount() {
|
||||
try {
|
||||
const accountsResult = await this.getAllAccounts();
|
||||
const accountsResult = await this.getAllAccounts()
|
||||
if (!accountsResult.success) {
|
||||
return { success: false, error: 'Failed to get accounts' };
|
||||
return { success: false, error: 'Failed to get accounts' }
|
||||
}
|
||||
|
||||
const availableAccounts = accountsResult.data.filter(account =>
|
||||
account.isActive && account.schedulable
|
||||
);
|
||||
const availableAccounts = accountsResult.data.filter(
|
||||
(account) => account.isActive && account.schedulable
|
||||
)
|
||||
|
||||
if (availableAccounts.length === 0) {
|
||||
return { success: false, error: 'No available Bedrock accounts' };
|
||||
return { success: false, error: 'No available Bedrock accounts' }
|
||||
}
|
||||
|
||||
// 简单的轮询选择策略 - 选择优先级最高的账户
|
||||
const selectedAccount = availableAccounts[0];
|
||||
const selectedAccount = availableAccounts[0]
|
||||
|
||||
// 获取完整账户信息(包含解密的凭证)
|
||||
const fullAccountResult = await this.getAccount(selectedAccount.id);
|
||||
const fullAccountResult = await this.getAccount(selectedAccount.id)
|
||||
if (!fullAccountResult.success) {
|
||||
return { success: false, error: 'Failed to get selected account details' };
|
||||
return { success: false, error: 'Failed to get selected account details' }
|
||||
}
|
||||
|
||||
logger.debug(`🎯 选择Bedrock账户 - ID: ${selectedAccount.id}, 名称: ${selectedAccount.name}`);
|
||||
logger.debug(`🎯 选择Bedrock账户 - ID: ${selectedAccount.id}, 名称: ${selectedAccount.name}`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
data: fullAccountResult.data
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ 选择Bedrock账户失败', error);
|
||||
return { success: false, error: error.message };
|
||||
logger.error('❌ 选择Bedrock账户失败', error)
|
||||
return { success: false, error: error.message }
|
||||
}
|
||||
}
|
||||
|
||||
// 🧪 测试账户连接
|
||||
async testAccount(accountId) {
|
||||
try {
|
||||
const accountResult = await this.getAccount(accountId);
|
||||
const accountResult = await this.getAccount(accountId)
|
||||
if (!accountResult.success) {
|
||||
return accountResult;
|
||||
return accountResult
|
||||
}
|
||||
|
||||
const account = accountResult.data;
|
||||
const account = accountResult.data
|
||||
|
||||
logger.info(`🧪 测试Bedrock账户连接 - ID: ${accountId}, 名称: ${account.name}`);
|
||||
logger.info(`🧪 测试Bedrock账户连接 - ID: ${accountId}, 名称: ${account.name}`)
|
||||
|
||||
// 尝试获取模型列表来测试连接
|
||||
const models = await bedrockRelayService.getAvailableModels(account);
|
||||
const models = await bedrockRelayService.getAvailableModels(account)
|
||||
|
||||
if (models && models.length > 0) {
|
||||
logger.info(`✅ Bedrock账户测试成功 - ID: ${accountId}, 发现 ${models.length} 个模型`);
|
||||
logger.info(`✅ Bedrock账户测试成功 - ID: ${accountId}, 发现 ${models.length} 个模型`)
|
||||
return {
|
||||
success: true,
|
||||
data: {
|
||||
@@ -300,40 +320,40 @@ class BedrockAccountService {
|
||||
region: account.region,
|
||||
credentialType: account.credentialType
|
||||
}
|
||||
};
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
success: false,
|
||||
error: 'Unable to retrieve models from Bedrock'
|
||||
};
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`❌ 测试Bedrock账户失败 - ID: ${accountId}`, error);
|
||||
logger.error(`❌ 测试Bedrock账户失败 - ID: ${accountId}`, error)
|
||||
return {
|
||||
success: false,
|
||||
error: error.message
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 🔐 加密AWS凭证
|
||||
_encryptAwsCredentials(credentials) {
|
||||
try {
|
||||
const key = crypto.createHash('sha256').update(config.security.encryptionKey).digest();
|
||||
const iv = crypto.randomBytes(16);
|
||||
const cipher = crypto.createCipheriv(this.ENCRYPTION_ALGORITHM, key, iv);
|
||||
const key = crypto.createHash('sha256').update(config.security.encryptionKey).digest()
|
||||
const iv = crypto.randomBytes(16)
|
||||
const cipher = crypto.createCipheriv(this.ENCRYPTION_ALGORITHM, key, iv)
|
||||
|
||||
const credentialsString = JSON.stringify(credentials);
|
||||
let encrypted = cipher.update(credentialsString, 'utf8', 'hex');
|
||||
encrypted += cipher.final('hex');
|
||||
const credentialsString = JSON.stringify(credentials)
|
||||
let encrypted = cipher.update(credentialsString, 'utf8', 'hex')
|
||||
encrypted += cipher.final('hex')
|
||||
|
||||
return {
|
||||
encrypted: encrypted,
|
||||
encrypted,
|
||||
iv: iv.toString('hex')
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ AWS凭证加密失败', error);
|
||||
throw new Error('Credentials encryption failed');
|
||||
logger.error('❌ AWS凭证加密失败', error)
|
||||
throw new Error('Credentials encryption failed')
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,70 +362,71 @@ class BedrockAccountService {
|
||||
try {
|
||||
// 检查数据格式
|
||||
if (!encryptedData || typeof encryptedData !== 'object') {
|
||||
logger.error('❌ 无效的加密数据格式:', encryptedData);
|
||||
throw new Error('Invalid encrypted data format');
|
||||
logger.error('❌ 无效的加密数据格式:', encryptedData)
|
||||
throw new Error('Invalid encrypted data format')
|
||||
}
|
||||
|
||||
// 检查是否为加密格式 (有 encrypted 和 iv 字段)
|
||||
if (encryptedData.encrypted && encryptedData.iv) {
|
||||
// 加密数据 - 进行解密
|
||||
const key = crypto.createHash('sha256').update(config.security.encryptionKey).digest();
|
||||
const iv = Buffer.from(encryptedData.iv, 'hex');
|
||||
const decipher = crypto.createDecipheriv(this.ENCRYPTION_ALGORITHM, key, iv);
|
||||
const key = crypto.createHash('sha256').update(config.security.encryptionKey).digest()
|
||||
const iv = Buffer.from(encryptedData.iv, 'hex')
|
||||
const decipher = crypto.createDecipheriv(this.ENCRYPTION_ALGORITHM, key, iv)
|
||||
|
||||
let decrypted = decipher.update(encryptedData.encrypted, 'hex', 'utf8');
|
||||
decrypted += decipher.final('utf8');
|
||||
let decrypted = decipher.update(encryptedData.encrypted, 'hex', 'utf8')
|
||||
decrypted += decipher.final('utf8')
|
||||
|
||||
return JSON.parse(decrypted);
|
||||
return JSON.parse(decrypted)
|
||||
} else if (encryptedData.accessKeyId) {
|
||||
// 纯文本数据 - 直接返回 (向后兼容)
|
||||
logger.warn('⚠️ 发现未加密的AWS凭证,建议更新账户以启用加密');
|
||||
return encryptedData;
|
||||
logger.warn('⚠️ 发现未加密的AWS凭证,建议更新账户以启用加密')
|
||||
return encryptedData
|
||||
} else {
|
||||
// 既不是加密格式也不是有效的凭证格式
|
||||
logger.error('❌ 缺少加密数据字段:', {
|
||||
hasEncrypted: !!encryptedData.encrypted,
|
||||
hasIv: !!encryptedData.iv,
|
||||
hasAccessKeyId: !!encryptedData.accessKeyId
|
||||
});
|
||||
throw new Error('Missing encrypted data fields or valid credentials');
|
||||
})
|
||||
throw new Error('Missing encrypted data fields or valid credentials')
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ AWS凭证解密失败', error);
|
||||
throw new Error('Credentials decryption failed');
|
||||
logger.error('❌ AWS凭证解密失败', error)
|
||||
throw new Error('Credentials decryption failed')
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 获取账户统计信息
|
||||
async getAccountStats() {
|
||||
try {
|
||||
const accountsResult = await this.getAllAccounts();
|
||||
const accountsResult = await this.getAllAccounts()
|
||||
if (!accountsResult.success) {
|
||||
return { success: false, error: accountsResult.error };
|
||||
return { success: false, error: accountsResult.error }
|
||||
}
|
||||
|
||||
const accounts = accountsResult.data;
|
||||
const accounts = accountsResult.data
|
||||
const stats = {
|
||||
total: accounts.length,
|
||||
active: accounts.filter(acc => acc.isActive).length,
|
||||
inactive: accounts.filter(acc => !acc.isActive).length,
|
||||
schedulable: accounts.filter(acc => acc.schedulable).length,
|
||||
active: accounts.filter((acc) => acc.isActive).length,
|
||||
inactive: accounts.filter((acc) => !acc.isActive).length,
|
||||
schedulable: accounts.filter((acc) => acc.schedulable).length,
|
||||
byRegion: {},
|
||||
byCredentialType: {}
|
||||
};
|
||||
}
|
||||
|
||||
// 按区域统计
|
||||
accounts.forEach(acc => {
|
||||
stats.byRegion[acc.region] = (stats.byRegion[acc.region] || 0) + 1;
|
||||
stats.byCredentialType[acc.credentialType] = (stats.byCredentialType[acc.credentialType] || 0) + 1;
|
||||
});
|
||||
accounts.forEach((acc) => {
|
||||
stats.byRegion[acc.region] = (stats.byRegion[acc.region] || 0) + 1
|
||||
stats.byCredentialType[acc.credentialType] =
|
||||
(stats.byCredentialType[acc.credentialType] || 0) + 1
|
||||
})
|
||||
|
||||
return { success: true, data: stats };
|
||||
return { success: true, data: stats }
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取Bedrock账户统计失败', error);
|
||||
return { success: false, error: error.message };
|
||||
logger.error('❌ 获取Bedrock账户统计失败', error)
|
||||
return { success: false, error: error.message }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new BedrockAccountService();
|
||||
module.exports = new BedrockAccountService()
|
||||
|
||||
@@ -1,38 +1,44 @@
|
||||
const { BedrockRuntimeClient, InvokeModelCommand, InvokeModelWithResponseStreamCommand } = require('@aws-sdk/client-bedrock-runtime');
|
||||
const { fromEnv } = require('@aws-sdk/credential-providers');
|
||||
const logger = require('../utils/logger');
|
||||
const config = require('../../config/config');
|
||||
const {
|
||||
BedrockRuntimeClient,
|
||||
InvokeModelCommand,
|
||||
InvokeModelWithResponseStreamCommand
|
||||
} = require('@aws-sdk/client-bedrock-runtime')
|
||||
const { fromEnv } = require('@aws-sdk/credential-providers')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
|
||||
class BedrockRelayService {
|
||||
constructor() {
|
||||
this.defaultRegion = process.env.AWS_REGION || config.bedrock?.defaultRegion || 'us-east-1';
|
||||
this.smallFastModelRegion = process.env.ANTHROPIC_SMALL_FAST_MODEL_AWS_REGION || this.defaultRegion;
|
||||
this.defaultRegion = process.env.AWS_REGION || config.bedrock?.defaultRegion || 'us-east-1'
|
||||
this.smallFastModelRegion =
|
||||
process.env.ANTHROPIC_SMALL_FAST_MODEL_AWS_REGION || this.defaultRegion
|
||||
|
||||
// 默认模型配置
|
||||
this.defaultModel = process.env.ANTHROPIC_MODEL || 'us.anthropic.claude-sonnet-4-20250514-v1:0';
|
||||
this.defaultSmallModel = process.env.ANTHROPIC_SMALL_FAST_MODEL || 'us.anthropic.claude-3-5-haiku-20241022-v1:0';
|
||||
this.defaultModel = process.env.ANTHROPIC_MODEL || 'us.anthropic.claude-sonnet-4-20250514-v1:0'
|
||||
this.defaultSmallModel =
|
||||
process.env.ANTHROPIC_SMALL_FAST_MODEL || 'us.anthropic.claude-3-5-haiku-20241022-v1:0'
|
||||
|
||||
// Token配置
|
||||
this.maxOutputTokens = parseInt(process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS) || 4096;
|
||||
this.maxThinkingTokens = parseInt(process.env.MAX_THINKING_TOKENS) || 1024;
|
||||
this.enablePromptCaching = process.env.DISABLE_PROMPT_CACHING !== '1';
|
||||
this.maxOutputTokens = parseInt(process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS) || 4096
|
||||
this.maxThinkingTokens = parseInt(process.env.MAX_THINKING_TOKENS) || 1024
|
||||
this.enablePromptCaching = process.env.DISABLE_PROMPT_CACHING !== '1'
|
||||
|
||||
// 创建Bedrock客户端
|
||||
this.clients = new Map(); // 缓存不同区域的客户端
|
||||
this.clients = new Map() // 缓存不同区域的客户端
|
||||
}
|
||||
|
||||
// 获取或创建Bedrock客户端
|
||||
_getBedrockClient(region = null, bedrockAccount = null) {
|
||||
const targetRegion = region || this.defaultRegion;
|
||||
const clientKey = `${targetRegion}-${bedrockAccount?.id || 'default'}`;
|
||||
const targetRegion = region || this.defaultRegion
|
||||
const clientKey = `${targetRegion}-${bedrockAccount?.id || 'default'}`
|
||||
|
||||
if (this.clients.has(clientKey)) {
|
||||
return this.clients.get(clientKey);
|
||||
return this.clients.get(clientKey)
|
||||
}
|
||||
|
||||
const clientConfig = {
|
||||
region: targetRegion
|
||||
};
|
||||
}
|
||||
|
||||
// 如果账户配置了特定的AWS凭证,使用它们
|
||||
if (bedrockAccount?.awsCredentials) {
|
||||
@@ -40,51 +46,55 @@ class BedrockRelayService {
|
||||
accessKeyId: bedrockAccount.awsCredentials.accessKeyId,
|
||||
secretAccessKey: bedrockAccount.awsCredentials.secretAccessKey,
|
||||
sessionToken: bedrockAccount.awsCredentials.sessionToken
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// 检查是否有环境变量凭证
|
||||
if (process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) {
|
||||
clientConfig.credentials = fromEnv();
|
||||
clientConfig.credentials = fromEnv()
|
||||
} else {
|
||||
throw new Error('AWS凭证未配置。请在Bedrock账户中配置AWS访问密钥,或设置环境变量AWS_ACCESS_KEY_ID和AWS_SECRET_ACCESS_KEY');
|
||||
throw new Error(
|
||||
'AWS凭证未配置。请在Bedrock账户中配置AWS访问密钥,或设置环境变量AWS_ACCESS_KEY_ID和AWS_SECRET_ACCESS_KEY'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const client = new BedrockRuntimeClient(clientConfig);
|
||||
this.clients.set(clientKey, client);
|
||||
const client = new BedrockRuntimeClient(clientConfig)
|
||||
this.clients.set(clientKey, client)
|
||||
|
||||
logger.debug(`🔧 Created Bedrock client for region: ${targetRegion}, account: ${bedrockAccount?.name || 'default'}`);
|
||||
return client;
|
||||
logger.debug(
|
||||
`🔧 Created Bedrock client for region: ${targetRegion}, account: ${bedrockAccount?.name || 'default'}`
|
||||
)
|
||||
return client
|
||||
}
|
||||
|
||||
// 处理非流式请求
|
||||
async handleNonStreamRequest(requestBody, bedrockAccount = null) {
|
||||
try {
|
||||
const modelId = this._selectModel(requestBody, bedrockAccount);
|
||||
const region = this._selectRegion(modelId, bedrockAccount);
|
||||
const client = this._getBedrockClient(region, bedrockAccount);
|
||||
const modelId = this._selectModel(requestBody, bedrockAccount)
|
||||
const region = this._selectRegion(modelId, bedrockAccount)
|
||||
const client = this._getBedrockClient(region, bedrockAccount)
|
||||
|
||||
// 转换请求格式为Bedrock格式
|
||||
const bedrockPayload = this._convertToBedrockFormat(requestBody);
|
||||
const bedrockPayload = this._convertToBedrockFormat(requestBody)
|
||||
|
||||
const command = new InvokeModelCommand({
|
||||
modelId: modelId,
|
||||
modelId,
|
||||
body: JSON.stringify(bedrockPayload),
|
||||
contentType: 'application/json',
|
||||
accept: 'application/json'
|
||||
});
|
||||
})
|
||||
|
||||
logger.debug(`🚀 Bedrock非流式请求 - 模型: ${modelId}, 区域: ${region}`);
|
||||
logger.debug(`🚀 Bedrock非流式请求 - 模型: ${modelId}, 区域: ${region}`)
|
||||
|
||||
const startTime = Date.now();
|
||||
const response = await client.send(command);
|
||||
const duration = Date.now() - startTime;
|
||||
const startTime = Date.now()
|
||||
const response = await client.send(command)
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
// 解析响应
|
||||
const responseBody = JSON.parse(new TextDecoder().decode(response.body));
|
||||
const claudeResponse = this._convertFromBedrockFormat(responseBody);
|
||||
const responseBody = JSON.parse(new TextDecoder().decode(response.body))
|
||||
const claudeResponse = this._convertFromBedrockFormat(responseBody)
|
||||
|
||||
logger.info(`✅ Bedrock请求完成 - 模型: ${modelId}, 耗时: ${duration}ms`);
|
||||
logger.info(`✅ Bedrock请求完成 - 模型: ${modelId}, 耗时: ${duration}ms`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
@@ -92,127 +102,129 @@ class BedrockRelayService {
|
||||
usage: claudeResponse.usage,
|
||||
model: modelId,
|
||||
duration
|
||||
};
|
||||
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Bedrock非流式请求失败:', error);
|
||||
throw this._handleBedrockError(error);
|
||||
logger.error('❌ Bedrock非流式请求失败:', error)
|
||||
throw this._handleBedrockError(error)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理流式请求
|
||||
async handleStreamRequest(requestBody, bedrockAccount = null, res) {
|
||||
try {
|
||||
const modelId = this._selectModel(requestBody, bedrockAccount);
|
||||
const region = this._selectRegion(modelId, bedrockAccount);
|
||||
const client = this._getBedrockClient(region, bedrockAccount);
|
||||
const modelId = this._selectModel(requestBody, bedrockAccount)
|
||||
const region = this._selectRegion(modelId, bedrockAccount)
|
||||
const client = this._getBedrockClient(region, bedrockAccount)
|
||||
|
||||
// 转换请求格式为Bedrock格式
|
||||
const bedrockPayload = this._convertToBedrockFormat(requestBody);
|
||||
const bedrockPayload = this._convertToBedrockFormat(requestBody)
|
||||
|
||||
const command = new InvokeModelWithResponseStreamCommand({
|
||||
modelId: modelId,
|
||||
modelId,
|
||||
body: JSON.stringify(bedrockPayload),
|
||||
contentType: 'application/json',
|
||||
accept: 'application/json'
|
||||
});
|
||||
})
|
||||
|
||||
logger.debug(`🌊 Bedrock流式请求 - 模型: ${modelId}, 区域: ${region}`);
|
||||
logger.debug(`🌊 Bedrock流式请求 - 模型: ${modelId}, 区域: ${region}`)
|
||||
|
||||
const startTime = Date.now();
|
||||
const response = await client.send(command);
|
||||
const startTime = Date.now()
|
||||
const response = await client.send(command)
|
||||
|
||||
// 设置SSE响应头
|
||||
res.writeHead(200, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
Connection: 'keep-alive',
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
'Access-Control-Allow-Headers': 'Content-Type, Authorization'
|
||||
});
|
||||
})
|
||||
|
||||
let totalUsage = null;
|
||||
let isFirstChunk = true;
|
||||
let totalUsage = null
|
||||
let isFirstChunk = true
|
||||
|
||||
// 处理流式响应
|
||||
for await (const chunk of response.body) {
|
||||
if (chunk.chunk) {
|
||||
const chunkData = JSON.parse(new TextDecoder().decode(chunk.chunk.bytes));
|
||||
const claudeEvent = this._convertBedrockStreamToClaudeFormat(chunkData, isFirstChunk);
|
||||
const chunkData = JSON.parse(new TextDecoder().decode(chunk.chunk.bytes))
|
||||
const claudeEvent = this._convertBedrockStreamToClaudeFormat(chunkData, isFirstChunk)
|
||||
|
||||
if (claudeEvent) {
|
||||
// 发送SSE事件
|
||||
res.write(`event: ${claudeEvent.type}\n`);
|
||||
res.write(`data: ${JSON.stringify(claudeEvent.data)}\n\n`);
|
||||
res.write(`event: ${claudeEvent.type}\n`)
|
||||
res.write(`data: ${JSON.stringify(claudeEvent.data)}\n\n`)
|
||||
|
||||
// 提取使用统计
|
||||
if (claudeEvent.type === 'message_stop' && claudeEvent.data.usage) {
|
||||
totalUsage = claudeEvent.data.usage;
|
||||
totalUsage = claudeEvent.data.usage
|
||||
}
|
||||
|
||||
isFirstChunk = false;
|
||||
isFirstChunk = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
logger.info(`✅ Bedrock流式请求完成 - 模型: ${modelId}, 耗时: ${duration}ms`);
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`✅ Bedrock流式请求完成 - 模型: ${modelId}, 耗时: ${duration}ms`)
|
||||
|
||||
// 发送结束事件
|
||||
res.write('event: done\n');
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.end();
|
||||
res.write('event: done\n')
|
||||
res.write('data: [DONE]\n\n')
|
||||
res.end()
|
||||
|
||||
return {
|
||||
success: true,
|
||||
usage: totalUsage,
|
||||
model: modelId,
|
||||
duration
|
||||
};
|
||||
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Bedrock流式请求失败:', error);
|
||||
logger.error('❌ Bedrock流式请求失败:', error)
|
||||
|
||||
// 发送错误事件
|
||||
if (!res.headersSent) {
|
||||
res.writeHead(500, { 'Content-Type': 'application/json' });
|
||||
res.writeHead(500, { 'Content-Type': 'application/json' })
|
||||
}
|
||||
|
||||
res.write('event: error\n');
|
||||
res.write(`data: ${JSON.stringify({ error: this._handleBedrockError(error).message })}\n\n`);
|
||||
res.end();
|
||||
res.write('event: error\n')
|
||||
res.write(`data: ${JSON.stringify({ error: this._handleBedrockError(error).message })}\n\n`)
|
||||
res.end()
|
||||
|
||||
throw this._handleBedrockError(error);
|
||||
throw this._handleBedrockError(error)
|
||||
}
|
||||
}
|
||||
|
||||
// 选择使用的模型
|
||||
_selectModel(requestBody, bedrockAccount) {
|
||||
let selectedModel;
|
||||
|
||||
let selectedModel
|
||||
|
||||
// 优先使用账户配置的模型
|
||||
if (bedrockAccount?.defaultModel) {
|
||||
selectedModel = bedrockAccount.defaultModel;
|
||||
logger.info(`🎯 使用账户配置的模型: ${selectedModel}`, { metadata: { source: 'account', accountId: bedrockAccount.id } });
|
||||
selectedModel = bedrockAccount.defaultModel
|
||||
logger.info(`🎯 使用账户配置的模型: ${selectedModel}`, {
|
||||
metadata: { source: 'account', accountId: bedrockAccount.id }
|
||||
})
|
||||
}
|
||||
// 检查请求中指定的模型
|
||||
else if (requestBody.model) {
|
||||
selectedModel = requestBody.model;
|
||||
logger.info(`🎯 使用请求指定的模型: ${selectedModel}`, { metadata: { source: 'request' } });
|
||||
selectedModel = requestBody.model
|
||||
logger.info(`🎯 使用请求指定的模型: ${selectedModel}`, { metadata: { source: 'request' } })
|
||||
}
|
||||
// 使用默认模型
|
||||
else {
|
||||
selectedModel = this.defaultModel;
|
||||
logger.info(`🎯 使用系统默认模型: ${selectedModel}`, { metadata: { source: 'default' } });
|
||||
selectedModel = this.defaultModel
|
||||
logger.info(`🎯 使用系统默认模型: ${selectedModel}`, { metadata: { source: 'default' } })
|
||||
}
|
||||
|
||||
// 如果是标准Claude模型名,需要映射为Bedrock格式
|
||||
const bedrockModel = this._mapToBedrockModel(selectedModel);
|
||||
const bedrockModel = this._mapToBedrockModel(selectedModel)
|
||||
if (bedrockModel !== selectedModel) {
|
||||
logger.info(`🔄 模型映射: ${selectedModel} → ${bedrockModel}`, { metadata: { originalModel: selectedModel, bedrockModel } });
|
||||
logger.info(`🔄 模型映射: ${selectedModel} → ${bedrockModel}`, {
|
||||
metadata: { originalModel: selectedModel, bedrockModel }
|
||||
})
|
||||
}
|
||||
|
||||
return bedrockModel;
|
||||
return bedrockModel
|
||||
}
|
||||
|
||||
// 将标准Claude模型名映射为Bedrock格式
|
||||
@@ -222,63 +234,65 @@ class BedrockRelayService {
|
||||
// Claude Sonnet 4
|
||||
'claude-sonnet-4': 'us.anthropic.claude-sonnet-4-20250514-v1:0',
|
||||
'claude-sonnet-4-20250514': 'us.anthropic.claude-sonnet-4-20250514-v1:0',
|
||||
|
||||
|
||||
// Claude Opus 4.1
|
||||
'claude-opus-4': 'us.anthropic.claude-opus-4-1-20250805-v1:0',
|
||||
'claude-opus-4-1': 'us.anthropic.claude-opus-4-1-20250805-v1:0',
|
||||
'claude-opus-4-1-20250805': 'us.anthropic.claude-opus-4-1-20250805-v1:0',
|
||||
|
||||
|
||||
// Claude 3.7 Sonnet
|
||||
'claude-3-7-sonnet': 'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
|
||||
'claude-3-7-sonnet-20250219': 'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
|
||||
|
||||
|
||||
// Claude 3.5 Sonnet v2
|
||||
'claude-3-5-sonnet': 'us.anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
'claude-3-5-sonnet-20241022': 'us.anthropic.claude-3-5-sonnet-20241022-v2:0',
|
||||
|
||||
|
||||
// Claude 3.5 Haiku
|
||||
'claude-3-5-haiku': 'us.anthropic.claude-3-5-haiku-20241022-v1:0',
|
||||
'claude-3-5-haiku-20241022': 'us.anthropic.claude-3-5-haiku-20241022-v1:0',
|
||||
|
||||
|
||||
// Claude 3 Sonnet
|
||||
'claude-3-sonnet': 'us.anthropic.claude-3-sonnet-20240229-v1:0',
|
||||
'claude-3-sonnet-20240229': 'us.anthropic.claude-3-sonnet-20240229-v1:0',
|
||||
|
||||
|
||||
// Claude 3 Haiku
|
||||
'claude-3-haiku': 'us.anthropic.claude-3-haiku-20240307-v1:0',
|
||||
'claude-3-haiku-20240307': 'us.anthropic.claude-3-haiku-20240307-v1:0'
|
||||
};
|
||||
}
|
||||
|
||||
// 如果已经是Bedrock格式,直接返回
|
||||
// Bedrock模型格式:{region}.anthropic.{model-name} 或 anthropic.{model-name}
|
||||
if (modelName.includes('.anthropic.') || modelName.startsWith('anthropic.')) {
|
||||
return modelName;
|
||||
return modelName
|
||||
}
|
||||
|
||||
// 查找映射
|
||||
const mappedModel = modelMapping[modelName];
|
||||
const mappedModel = modelMapping[modelName]
|
||||
if (mappedModel) {
|
||||
return mappedModel;
|
||||
return mappedModel
|
||||
}
|
||||
|
||||
// 如果没有找到映射,返回原始模型名(可能会导致错误,但保持向后兼容)
|
||||
logger.warn(`⚠️ 未找到模型映射: ${modelName},使用原始模型名`, { metadata: { originalModel: modelName } });
|
||||
return modelName;
|
||||
logger.warn(`⚠️ 未找到模型映射: ${modelName},使用原始模型名`, {
|
||||
metadata: { originalModel: modelName }
|
||||
})
|
||||
return modelName
|
||||
}
|
||||
|
||||
// 选择使用的区域
|
||||
_selectRegion(modelId, bedrockAccount) {
|
||||
// 优先使用账户配置的区域
|
||||
if (bedrockAccount?.region) {
|
||||
return bedrockAccount.region;
|
||||
return bedrockAccount.region
|
||||
}
|
||||
|
||||
// 对于小模型,使用专门的区域配置
|
||||
if (modelId.includes('haiku')) {
|
||||
return this.smallFastModelRegion;
|
||||
return this.smallFastModelRegion
|
||||
}
|
||||
|
||||
return this.defaultRegion;
|
||||
return this.defaultRegion
|
||||
}
|
||||
|
||||
// 转换Claude格式请求到Bedrock格式
|
||||
@@ -287,40 +301,40 @@ class BedrockRelayService {
|
||||
anthropic_version: 'bedrock-2023-05-31',
|
||||
max_tokens: Math.min(requestBody.max_tokens || this.maxOutputTokens, this.maxOutputTokens),
|
||||
messages: requestBody.messages || []
|
||||
};
|
||||
}
|
||||
|
||||
// 添加系统提示词
|
||||
if (requestBody.system) {
|
||||
bedrockPayload.system = requestBody.system;
|
||||
bedrockPayload.system = requestBody.system
|
||||
}
|
||||
|
||||
// 添加其他参数
|
||||
if (requestBody.temperature !== undefined) {
|
||||
bedrockPayload.temperature = requestBody.temperature;
|
||||
bedrockPayload.temperature = requestBody.temperature
|
||||
}
|
||||
|
||||
if (requestBody.top_p !== undefined) {
|
||||
bedrockPayload.top_p = requestBody.top_p;
|
||||
bedrockPayload.top_p = requestBody.top_p
|
||||
}
|
||||
|
||||
if (requestBody.top_k !== undefined) {
|
||||
bedrockPayload.top_k = requestBody.top_k;
|
||||
bedrockPayload.top_k = requestBody.top_k
|
||||
}
|
||||
|
||||
if (requestBody.stop_sequences) {
|
||||
bedrockPayload.stop_sequences = requestBody.stop_sequences;
|
||||
bedrockPayload.stop_sequences = requestBody.stop_sequences
|
||||
}
|
||||
|
||||
// 工具调用支持
|
||||
if (requestBody.tools) {
|
||||
bedrockPayload.tools = requestBody.tools;
|
||||
bedrockPayload.tools = requestBody.tools
|
||||
}
|
||||
|
||||
if (requestBody.tool_choice) {
|
||||
bedrockPayload.tool_choice = requestBody.tool_choice;
|
||||
bedrockPayload.tool_choice = requestBody.tool_choice
|
||||
}
|
||||
|
||||
return bedrockPayload;
|
||||
return bedrockPayload
|
||||
}
|
||||
|
||||
// 转换Bedrock响应到Claude格式
|
||||
@@ -337,7 +351,7 @@ class BedrockRelayService {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 转换Bedrock流事件到Claude SSE格式
|
||||
@@ -355,7 +369,7 @@ class BedrockRelayService {
|
||||
stop_sequence: null,
|
||||
usage: bedrockChunk.message?.usage || { input_tokens: 0, output_tokens: 0 }
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (bedrockChunk.type === 'content_block_delta') {
|
||||
@@ -365,7 +379,7 @@ class BedrockRelayService {
|
||||
index: bedrockChunk.index || 0,
|
||||
delta: bedrockChunk.delta || {}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (bedrockChunk.type === 'message_delta') {
|
||||
@@ -375,7 +389,7 @@ class BedrockRelayService {
|
||||
delta: bedrockChunk.delta || {},
|
||||
usage: bedrockChunk.usage || {}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (bedrockChunk.type === 'message_stop') {
|
||||
@@ -384,39 +398,39 @@ class BedrockRelayService {
|
||||
data: {
|
||||
usage: bedrockChunk.usage || {}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
// 处理Bedrock错误
|
||||
_handleBedrockError(error) {
|
||||
const errorMessage = error.message || 'Unknown Bedrock error';
|
||||
const errorMessage = error.message || 'Unknown Bedrock error'
|
||||
|
||||
if (error.name === 'ValidationException') {
|
||||
return new Error(`Bedrock参数验证失败: ${errorMessage}`);
|
||||
return new Error(`Bedrock参数验证失败: ${errorMessage}`)
|
||||
}
|
||||
|
||||
if (error.name === 'ThrottlingException') {
|
||||
return new Error('Bedrock请求限流,请稍后重试');
|
||||
return new Error('Bedrock请求限流,请稍后重试')
|
||||
}
|
||||
|
||||
if (error.name === 'AccessDeniedException') {
|
||||
return new Error('Bedrock访问被拒绝,请检查IAM权限');
|
||||
return new Error('Bedrock访问被拒绝,请检查IAM权限')
|
||||
}
|
||||
|
||||
if (error.name === 'ModelNotReadyException') {
|
||||
return new Error('Bedrock模型未就绪,请稍后重试');
|
||||
return new Error('Bedrock模型未就绪,请稍后重试')
|
||||
}
|
||||
|
||||
return new Error(`Bedrock服务错误: ${errorMessage}`);
|
||||
return new Error(`Bedrock服务错误: ${errorMessage}`)
|
||||
}
|
||||
|
||||
// 获取可用模型列表
|
||||
async getAvailableModels(bedrockAccount = null) {
|
||||
try {
|
||||
const region = bedrockAccount?.region || this.defaultRegion;
|
||||
const region = bedrockAccount?.region || this.defaultRegion
|
||||
|
||||
// Bedrock暂不支持列出推理配置文件的API,返回预定义的模型列表
|
||||
const models = [
|
||||
@@ -450,16 +464,15 @@ class BedrockRelayService {
|
||||
provider: 'anthropic',
|
||||
type: 'bedrock'
|
||||
}
|
||||
];
|
||||
|
||||
logger.debug(`📋 返回Bedrock可用模型 ${models.length} 个, 区域: ${region}`);
|
||||
return models;
|
||||
]
|
||||
|
||||
logger.debug(`📋 返回Bedrock可用模型 ${models.length} 个, 区域: ${region}`)
|
||||
return models
|
||||
} catch (error) {
|
||||
logger.error('❌ 获取Bedrock模型列表失败:', error);
|
||||
return [];
|
||||
logger.error('❌ 获取Bedrock模型列表失败:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new BedrockRelayService();
|
||||
module.exports = new BedrockRelayService()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,8 +3,8 @@
|
||||
* 负责存储和管理不同账号使用的 Claude Code headers
|
||||
*/
|
||||
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
|
||||
class ClaudeCodeHeadersService {
|
||||
constructor() {
|
||||
@@ -22,8 +22,8 @@ class ClaudeCodeHeadersService {
|
||||
'user-agent': 'claude-cli/1.0.57 (external, cli)',
|
||||
'accept-language': '*',
|
||||
'sec-fetch-mode': 'cors'
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 需要捕获的 Claude Code 特定 headers
|
||||
this.claudeCodeHeaderKeys = [
|
||||
'x-stainless-retry-count',
|
||||
@@ -40,16 +40,18 @@ class ClaudeCodeHeadersService {
|
||||
'accept-language',
|
||||
'sec-fetch-mode',
|
||||
'accept-encoding'
|
||||
];
|
||||
]
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 user-agent 中提取版本号
|
||||
*/
|
||||
extractVersionFromUserAgent(userAgent) {
|
||||
if (!userAgent) return null;
|
||||
const match = userAgent.match(/claude-cli\/(\d+\.\d+\.\d+)/);
|
||||
return match ? match[1] : null;
|
||||
if (!userAgent) {
|
||||
return null
|
||||
}
|
||||
const match = userAgent.match(/claude-cli\/(\d+\.\d+\.\d+)/)
|
||||
return match ? match[1] : null
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -57,43 +59,49 @@ class ClaudeCodeHeadersService {
|
||||
* @returns {number} 1 if v1 > v2, -1 if v1 < v2, 0 if equal
|
||||
*/
|
||||
compareVersions(v1, v2) {
|
||||
if (!v1 || !v2) return 0;
|
||||
|
||||
const parts1 = v1.split('.').map(Number);
|
||||
const parts2 = v2.split('.').map(Number);
|
||||
|
||||
for (let i = 0; i < Math.max(parts1.length, parts2.length); i++) {
|
||||
const p1 = parts1[i] || 0;
|
||||
const p2 = parts2[i] || 0;
|
||||
|
||||
if (p1 > p2) return 1;
|
||||
if (p1 < p2) return -1;
|
||||
if (!v1 || !v2) {
|
||||
return 0
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
const parts1 = v1.split('.').map(Number)
|
||||
const parts2 = v2.split('.').map(Number)
|
||||
|
||||
for (let i = 0; i < Math.max(parts1.length, parts2.length); i++) {
|
||||
const p1 = parts1[i] || 0
|
||||
const p2 = parts2[i] || 0
|
||||
|
||||
if (p1 > p2) {
|
||||
return 1
|
||||
}
|
||||
if (p1 < p2) {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
/**
|
||||
* 从客户端 headers 中提取 Claude Code 相关的 headers
|
||||
*/
|
||||
extractClaudeCodeHeaders(clientHeaders) {
|
||||
const headers = {};
|
||||
|
||||
const headers = {}
|
||||
|
||||
// 转换所有 header keys 为小写进行比较
|
||||
const lowerCaseHeaders = {};
|
||||
Object.keys(clientHeaders || {}).forEach(key => {
|
||||
lowerCaseHeaders[key.toLowerCase()] = clientHeaders[key];
|
||||
});
|
||||
|
||||
const lowerCaseHeaders = {}
|
||||
Object.keys(clientHeaders || {}).forEach((key) => {
|
||||
lowerCaseHeaders[key.toLowerCase()] = clientHeaders[key]
|
||||
})
|
||||
|
||||
// 提取需要的 headers
|
||||
this.claudeCodeHeaderKeys.forEach(key => {
|
||||
const lowerKey = key.toLowerCase();
|
||||
this.claudeCodeHeaderKeys.forEach((key) => {
|
||||
const lowerKey = key.toLowerCase()
|
||||
if (lowerCaseHeaders[lowerKey]) {
|
||||
headers[key] = lowerCaseHeaders[lowerKey];
|
||||
headers[key] = lowerCaseHeaders[lowerKey]
|
||||
}
|
||||
});
|
||||
|
||||
return headers;
|
||||
})
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -101,48 +109,47 @@ class ClaudeCodeHeadersService {
|
||||
*/
|
||||
async storeAccountHeaders(accountId, clientHeaders) {
|
||||
try {
|
||||
const extractedHeaders = this.extractClaudeCodeHeaders(clientHeaders);
|
||||
|
||||
const extractedHeaders = this.extractClaudeCodeHeaders(clientHeaders)
|
||||
|
||||
// 检查是否有 user-agent
|
||||
const userAgent = extractedHeaders['user-agent'];
|
||||
const userAgent = extractedHeaders['user-agent']
|
||||
if (!userAgent || !userAgent.includes('claude-cli')) {
|
||||
// 不是 Claude Code 的请求,不存储
|
||||
return;
|
||||
return
|
||||
}
|
||||
|
||||
const version = this.extractVersionFromUserAgent(userAgent);
|
||||
|
||||
const version = this.extractVersionFromUserAgent(userAgent)
|
||||
if (!version) {
|
||||
logger.warn(`⚠️ Failed to extract version from user-agent: ${userAgent}`);
|
||||
return;
|
||||
logger.warn(`⚠️ Failed to extract version from user-agent: ${userAgent}`)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 获取当前存储的 headers
|
||||
const key = `claude_code_headers:${accountId}`;
|
||||
const currentData = await redis.getClient().get(key);
|
||||
|
||||
const key = `claude_code_headers:${accountId}`
|
||||
const currentData = await redis.getClient().get(key)
|
||||
|
||||
if (currentData) {
|
||||
const current = JSON.parse(currentData);
|
||||
const currentVersion = this.extractVersionFromUserAgent(current.headers['user-agent']);
|
||||
|
||||
const current = JSON.parse(currentData)
|
||||
const currentVersion = this.extractVersionFromUserAgent(current.headers['user-agent'])
|
||||
|
||||
// 只有新版本更高时才更新
|
||||
if (this.compareVersions(version, currentVersion) <= 0) {
|
||||
return;
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 存储新的 headers
|
||||
const data = {
|
||||
headers: extractedHeaders,
|
||||
version: version,
|
||||
version,
|
||||
updatedAt: new Date().toISOString()
|
||||
};
|
||||
|
||||
await redis.getClient().setex(key, 86400 * 7, JSON.stringify(data)); // 7天过期
|
||||
|
||||
logger.info(`✅ Stored Claude Code headers for account ${accountId}, version: ${version}`);
|
||||
|
||||
}
|
||||
|
||||
await redis.getClient().setex(key, 86400 * 7, JSON.stringify(data)) // 7天过期
|
||||
|
||||
logger.info(`✅ Stored Claude Code headers for account ${accountId}, version: ${version}`)
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to store Claude Code headers for account ${accountId}:`, error);
|
||||
logger.error(`❌ Failed to store Claude Code headers for account ${accountId}:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,22 +158,23 @@ class ClaudeCodeHeadersService {
|
||||
*/
|
||||
async getAccountHeaders(accountId) {
|
||||
try {
|
||||
const key = `claude_code_headers:${accountId}`;
|
||||
const data = await redis.getClient().get(key);
|
||||
|
||||
const key = `claude_code_headers:${accountId}`
|
||||
const data = await redis.getClient().get(key)
|
||||
|
||||
if (data) {
|
||||
const parsed = JSON.parse(data);
|
||||
logger.debug(`📋 Retrieved Claude Code headers for account ${accountId}, version: ${parsed.version}`);
|
||||
return parsed.headers;
|
||||
const parsed = JSON.parse(data)
|
||||
logger.debug(
|
||||
`📋 Retrieved Claude Code headers for account ${accountId}, version: ${parsed.version}`
|
||||
)
|
||||
return parsed.headers
|
||||
}
|
||||
|
||||
|
||||
// 返回默认 headers
|
||||
logger.debug(`📋 Using default Claude Code headers for account ${accountId}`);
|
||||
return this.defaultHeaders;
|
||||
|
||||
logger.debug(`📋 Using default Claude Code headers for account ${accountId}`)
|
||||
return this.defaultHeaders
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to get Claude Code headers for account ${accountId}:`, error);
|
||||
return this.defaultHeaders;
|
||||
logger.error(`❌ Failed to get Claude Code headers for account ${accountId}:`, error)
|
||||
return this.defaultHeaders
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,11 +183,11 @@ class ClaudeCodeHeadersService {
|
||||
*/
|
||||
async clearAccountHeaders(accountId) {
|
||||
try {
|
||||
const key = `claude_code_headers:${accountId}`;
|
||||
await redis.getClient().del(key);
|
||||
logger.info(`🗑️ Cleared Claude Code headers for account ${accountId}`);
|
||||
const key = `claude_code_headers:${accountId}`
|
||||
await redis.getClient().del(key)
|
||||
logger.info(`🗑️ Cleared Claude Code headers for account ${accountId}`)
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to clear Claude Code headers for account ${accountId}:`, error);
|
||||
logger.error(`❌ Failed to clear Claude Code headers for account ${accountId}:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,25 +196,24 @@ class ClaudeCodeHeadersService {
|
||||
*/
|
||||
async getAllAccountHeaders() {
|
||||
try {
|
||||
const pattern = 'claude_code_headers:*';
|
||||
const keys = await redis.getClient().keys(pattern);
|
||||
|
||||
const results = {};
|
||||
const pattern = 'claude_code_headers:*'
|
||||
const keys = await redis.getClient().keys(pattern)
|
||||
|
||||
const results = {}
|
||||
for (const key of keys) {
|
||||
const accountId = key.replace('claude_code_headers:', '');
|
||||
const data = await redis.getClient().get(key);
|
||||
const accountId = key.replace('claude_code_headers:', '')
|
||||
const data = await redis.getClient().get(key)
|
||||
if (data) {
|
||||
results[accountId] = JSON.parse(data);
|
||||
results[accountId] = JSON.parse(data)
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
|
||||
|
||||
return results
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get all account headers:', error);
|
||||
return {};
|
||||
logger.error('❌ Failed to get all account headers:', error)
|
||||
return {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new ClaudeCodeHeadersService();
|
||||
module.exports = new ClaudeCodeHeadersService()
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const crypto = require('crypto');
|
||||
const { SocksProxyAgent } = require('socks-proxy-agent');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const config = require('../../config/config');
|
||||
const { v4: uuidv4 } = require('uuid')
|
||||
const crypto = require('crypto')
|
||||
const { SocksProxyAgent } = require('socks-proxy-agent')
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
|
||||
class ClaudeConsoleAccountService {
|
||||
constructor() {
|
||||
// 加密相关常量
|
||||
this.ENCRYPTION_ALGORITHM = 'aes-256-cbc';
|
||||
this.ENCRYPTION_SALT = 'claude-console-salt';
|
||||
|
||||
this.ENCRYPTION_ALGORITHM = 'aes-256-cbc'
|
||||
this.ENCRYPTION_SALT = 'claude-console-salt'
|
||||
|
||||
// Redis键前缀
|
||||
this.ACCOUNT_KEY_PREFIX = 'claude_console_account:';
|
||||
this.SHARED_ACCOUNTS_KEY = 'shared_claude_console_accounts';
|
||||
this.ACCOUNT_KEY_PREFIX = 'claude_console_account:'
|
||||
this.SHARED_ACCOUNTS_KEY = 'shared_claude_console_accounts'
|
||||
}
|
||||
|
||||
// 🏢 创建Claude Console账户
|
||||
@@ -32,24 +32,24 @@ class ClaudeConsoleAccountService {
|
||||
isActive = true,
|
||||
accountType = 'shared', // 'dedicated' or 'shared'
|
||||
schedulable = true // 是否可被调度
|
||||
} = options;
|
||||
} = options
|
||||
|
||||
// 验证必填字段
|
||||
if (!apiUrl || !apiKey) {
|
||||
throw new Error('API URL and API Key are required for Claude Console account');
|
||||
throw new Error('API URL and API Key are required for Claude Console account')
|
||||
}
|
||||
|
||||
const accountId = uuidv4();
|
||||
|
||||
const accountId = uuidv4()
|
||||
|
||||
// 处理 supportedModels,确保向后兼容
|
||||
const processedModels = this._processModelMapping(supportedModels);
|
||||
|
||||
const processedModels = this._processModelMapping(supportedModels)
|
||||
|
||||
const accountData = {
|
||||
id: accountId,
|
||||
platform: 'claude-console',
|
||||
name,
|
||||
description,
|
||||
apiUrl: apiUrl,
|
||||
apiUrl,
|
||||
apiKey: this._encryptSensitiveData(apiKey),
|
||||
priority: priority.toString(),
|
||||
supportedModels: JSON.stringify(processedModels),
|
||||
@@ -67,24 +67,23 @@ class ClaudeConsoleAccountService {
|
||||
rateLimitStatus: '',
|
||||
// 调度控制
|
||||
schedulable: schedulable.toString()
|
||||
};
|
||||
}
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
logger.debug(
|
||||
`[DEBUG] Saving account data to Redis with key: ${this.ACCOUNT_KEY_PREFIX}${accountId}`
|
||||
)
|
||||
logger.debug(`[DEBUG] Account data to save: ${JSON.stringify(accountData, null, 2)}`)
|
||||
|
||||
await client.hset(`${this.ACCOUNT_KEY_PREFIX}${accountId}`, accountData)
|
||||
|
||||
const client = redis.getClientSafe();
|
||||
logger.debug(`[DEBUG] Saving account data to Redis with key: ${this.ACCOUNT_KEY_PREFIX}${accountId}`);
|
||||
logger.debug(`[DEBUG] Account data to save: ${JSON.stringify(accountData, null, 2)}`);
|
||||
|
||||
await client.hset(
|
||||
`${this.ACCOUNT_KEY_PREFIX}${accountId}`,
|
||||
accountData
|
||||
);
|
||||
|
||||
// 如果是共享账户,添加到共享账户集合
|
||||
if (accountType === 'shared') {
|
||||
await client.sadd(this.SHARED_ACCOUNTS_KEY, accountId);
|
||||
await client.sadd(this.SHARED_ACCOUNTS_KEY, accountId)
|
||||
}
|
||||
|
||||
logger.success(`🏢 Created Claude Console account: ${name} (${accountId})`);
|
||||
|
||||
|
||||
logger.success(`🏢 Created Claude Console account: ${name} (${accountId})`)
|
||||
|
||||
return {
|
||||
id: accountId,
|
||||
name,
|
||||
@@ -99,22 +98,22 @@ class ClaudeConsoleAccountService {
|
||||
accountType,
|
||||
status: 'active',
|
||||
createdAt: accountData.createdAt
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 📋 获取所有Claude Console账户
|
||||
async getAllAccounts() {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const keys = await client.keys(`${this.ACCOUNT_KEY_PREFIX}*`);
|
||||
const accounts = [];
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const keys = await client.keys(`${this.ACCOUNT_KEY_PREFIX}*`)
|
||||
const accounts = []
|
||||
|
||||
for (const key of keys) {
|
||||
const accountData = await client.hgetall(key);
|
||||
const accountData = await client.hgetall(key)
|
||||
if (accountData && Object.keys(accountData).length > 0) {
|
||||
// 获取限流状态信息
|
||||
const rateLimitInfo = this._getRateLimitInfo(accountData);
|
||||
|
||||
const rateLimitInfo = this._getRateLimitInfo(accountData)
|
||||
|
||||
accounts.push({
|
||||
id: accountData.id,
|
||||
platform: accountData.platform,
|
||||
@@ -134,356 +133,379 @@ class ClaudeConsoleAccountService {
|
||||
lastUsedAt: accountData.lastUsedAt,
|
||||
rateLimitStatus: rateLimitInfo,
|
||||
schedulable: accountData.schedulable !== 'false' // 默认为true,只有明确设置为false才不可调度
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return accounts;
|
||||
|
||||
return accounts
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to get Claude Console accounts:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to get Claude Console accounts:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 获取单个账户(内部使用,包含敏感信息)
|
||||
async getAccount(accountId) {
|
||||
const client = redis.getClientSafe();
|
||||
logger.debug(`[DEBUG] Getting account data for ID: ${accountId}`);
|
||||
const accountData = await client.hgetall(`${this.ACCOUNT_KEY_PREFIX}${accountId}`);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
logger.debug(`[DEBUG] Getting account data for ID: ${accountId}`)
|
||||
const accountData = await client.hgetall(`${this.ACCOUNT_KEY_PREFIX}${accountId}`)
|
||||
|
||||
if (!accountData || Object.keys(accountData).length === 0) {
|
||||
logger.debug(`[DEBUG] No account data found for ID: ${accountId}`);
|
||||
return null;
|
||||
logger.debug(`[DEBUG] No account data found for ID: ${accountId}`)
|
||||
return null
|
||||
}
|
||||
|
||||
logger.debug(`[DEBUG] Raw account data keys: ${Object.keys(accountData).join(', ')}`);
|
||||
logger.debug(`[DEBUG] Raw supportedModels value: ${accountData.supportedModels}`);
|
||||
|
||||
|
||||
logger.debug(`[DEBUG] Raw account data keys: ${Object.keys(accountData).join(', ')}`)
|
||||
logger.debug(`[DEBUG] Raw supportedModels value: ${accountData.supportedModels}`)
|
||||
|
||||
// 解密敏感字段(只解密apiKey,apiUrl不加密)
|
||||
const decryptedKey = this._decryptSensitiveData(accountData.apiKey);
|
||||
logger.debug(`[DEBUG] URL exists: ${!!accountData.apiUrl}, Decrypted key exists: ${!!decryptedKey}`);
|
||||
|
||||
accountData.apiKey = decryptedKey;
|
||||
|
||||
const decryptedKey = this._decryptSensitiveData(accountData.apiKey)
|
||||
logger.debug(
|
||||
`[DEBUG] URL exists: ${!!accountData.apiUrl}, Decrypted key exists: ${!!decryptedKey}`
|
||||
)
|
||||
|
||||
accountData.apiKey = decryptedKey
|
||||
|
||||
// 解析JSON字段
|
||||
const parsedModels = JSON.parse(accountData.supportedModels || '[]');
|
||||
logger.debug(`[DEBUG] Parsed supportedModels: ${JSON.stringify(parsedModels)}`);
|
||||
|
||||
accountData.supportedModels = parsedModels;
|
||||
accountData.priority = parseInt(accountData.priority) || 50;
|
||||
accountData.rateLimitDuration = parseInt(accountData.rateLimitDuration) || 60;
|
||||
accountData.isActive = accountData.isActive === 'true';
|
||||
accountData.schedulable = accountData.schedulable !== 'false'; // 默认为true
|
||||
|
||||
const parsedModels = JSON.parse(accountData.supportedModels || '[]')
|
||||
logger.debug(`[DEBUG] Parsed supportedModels: ${JSON.stringify(parsedModels)}`)
|
||||
|
||||
accountData.supportedModels = parsedModels
|
||||
accountData.priority = parseInt(accountData.priority) || 50
|
||||
accountData.rateLimitDuration = parseInt(accountData.rateLimitDuration) || 60
|
||||
accountData.isActive = accountData.isActive === 'true'
|
||||
accountData.schedulable = accountData.schedulable !== 'false' // 默认为true
|
||||
|
||||
if (accountData.proxy) {
|
||||
accountData.proxy = JSON.parse(accountData.proxy);
|
||||
accountData.proxy = JSON.parse(accountData.proxy)
|
||||
}
|
||||
|
||||
logger.debug(`[DEBUG] Final account data - name: ${accountData.name}, hasApiUrl: ${!!accountData.apiUrl}, hasApiKey: ${!!accountData.apiKey}, supportedModels: ${JSON.stringify(accountData.supportedModels)}`);
|
||||
|
||||
return accountData;
|
||||
|
||||
logger.debug(
|
||||
`[DEBUG] Final account data - name: ${accountData.name}, hasApiUrl: ${!!accountData.apiUrl}, hasApiKey: ${!!accountData.apiKey}, supportedModels: ${JSON.stringify(accountData.supportedModels)}`
|
||||
)
|
||||
|
||||
return accountData
|
||||
}
|
||||
|
||||
// 📝 更新账户
|
||||
async updateAccount(accountId, updates) {
|
||||
try {
|
||||
const existingAccount = await this.getAccount(accountId);
|
||||
const existingAccount = await this.getAccount(accountId)
|
||||
if (!existingAccount) {
|
||||
throw new Error('Account not found');
|
||||
throw new Error('Account not found')
|
||||
}
|
||||
|
||||
const client = redis.getClientSafe();
|
||||
const updatedData = {};
|
||||
const client = redis.getClientSafe()
|
||||
const updatedData = {}
|
||||
|
||||
// 处理各个字段的更新
|
||||
logger.debug(`[DEBUG] Update request received with fields: ${Object.keys(updates).join(', ')}`);
|
||||
logger.debug(`[DEBUG] Updates content: ${JSON.stringify(updates, null, 2)}`);
|
||||
|
||||
if (updates.name !== undefined) updatedData.name = updates.name;
|
||||
if (updates.description !== undefined) updatedData.description = updates.description;
|
||||
logger.debug(
|
||||
`[DEBUG] Update request received with fields: ${Object.keys(updates).join(', ')}`
|
||||
)
|
||||
logger.debug(`[DEBUG] Updates content: ${JSON.stringify(updates, null, 2)}`)
|
||||
|
||||
if (updates.name !== undefined) {
|
||||
updatedData.name = updates.name
|
||||
}
|
||||
if (updates.description !== undefined) {
|
||||
updatedData.description = updates.description
|
||||
}
|
||||
if (updates.apiUrl !== undefined) {
|
||||
logger.debug(`[DEBUG] Updating apiUrl from frontend: ${updates.apiUrl}`);
|
||||
updatedData.apiUrl = updates.apiUrl;
|
||||
logger.debug(`[DEBUG] Updating apiUrl from frontend: ${updates.apiUrl}`)
|
||||
updatedData.apiUrl = updates.apiUrl
|
||||
}
|
||||
if (updates.apiKey !== undefined) {
|
||||
logger.debug(`[DEBUG] Updating apiKey (length: ${updates.apiKey?.length})`);
|
||||
updatedData.apiKey = this._encryptSensitiveData(updates.apiKey);
|
||||
logger.debug(`[DEBUG] Updating apiKey (length: ${updates.apiKey?.length})`)
|
||||
updatedData.apiKey = this._encryptSensitiveData(updates.apiKey)
|
||||
}
|
||||
if (updates.priority !== undefined) {
|
||||
updatedData.priority = updates.priority.toString()
|
||||
}
|
||||
if (updates.priority !== undefined) updatedData.priority = updates.priority.toString();
|
||||
if (updates.supportedModels !== undefined) {
|
||||
logger.debug(`[DEBUG] Updating supportedModels: ${JSON.stringify(updates.supportedModels)}`);
|
||||
logger.debug(`[DEBUG] Updating supportedModels: ${JSON.stringify(updates.supportedModels)}`)
|
||||
// 处理 supportedModels,确保向后兼容
|
||||
const processedModels = this._processModelMapping(updates.supportedModels);
|
||||
updatedData.supportedModels = JSON.stringify(processedModels);
|
||||
const processedModels = this._processModelMapping(updates.supportedModels)
|
||||
updatedData.supportedModels = JSON.stringify(processedModels)
|
||||
}
|
||||
if (updates.userAgent !== undefined) {
|
||||
updatedData.userAgent = updates.userAgent
|
||||
}
|
||||
if (updates.rateLimitDuration !== undefined) {
|
||||
updatedData.rateLimitDuration = updates.rateLimitDuration.toString()
|
||||
}
|
||||
if (updates.proxy !== undefined) {
|
||||
updatedData.proxy = updates.proxy ? JSON.stringify(updates.proxy) : ''
|
||||
}
|
||||
if (updates.isActive !== undefined) {
|
||||
updatedData.isActive = updates.isActive.toString()
|
||||
}
|
||||
if (updates.schedulable !== undefined) {
|
||||
updatedData.schedulable = updates.schedulable.toString()
|
||||
}
|
||||
if (updates.userAgent !== undefined) updatedData.userAgent = updates.userAgent;
|
||||
if (updates.rateLimitDuration !== undefined) updatedData.rateLimitDuration = updates.rateLimitDuration.toString();
|
||||
if (updates.proxy !== undefined) updatedData.proxy = updates.proxy ? JSON.stringify(updates.proxy) : '';
|
||||
if (updates.isActive !== undefined) updatedData.isActive = updates.isActive.toString();
|
||||
if (updates.schedulable !== undefined) updatedData.schedulable = updates.schedulable.toString();
|
||||
|
||||
// 处理账户类型变更
|
||||
if (updates.accountType && updates.accountType !== existingAccount.accountType) {
|
||||
updatedData.accountType = updates.accountType;
|
||||
|
||||
updatedData.accountType = updates.accountType
|
||||
|
||||
if (updates.accountType === 'shared') {
|
||||
await client.sadd(this.SHARED_ACCOUNTS_KEY, accountId);
|
||||
await client.sadd(this.SHARED_ACCOUNTS_KEY, accountId)
|
||||
} else {
|
||||
await client.srem(this.SHARED_ACCOUNTS_KEY, accountId);
|
||||
await client.srem(this.SHARED_ACCOUNTS_KEY, accountId)
|
||||
}
|
||||
}
|
||||
|
||||
updatedData.updatedAt = new Date().toISOString();
|
||||
|
||||
logger.debug(`[DEBUG] Final updatedData to save: ${JSON.stringify(updatedData, null, 2)}`);
|
||||
logger.debug(`[DEBUG] Updating Redis key: ${this.ACCOUNT_KEY_PREFIX}${accountId}`);
|
||||
|
||||
await client.hset(
|
||||
`${this.ACCOUNT_KEY_PREFIX}${accountId}`,
|
||||
updatedData
|
||||
);
|
||||
|
||||
logger.success(`📝 Updated Claude Console account: ${accountId}`);
|
||||
|
||||
return { success: true };
|
||||
updatedData.updatedAt = new Date().toISOString()
|
||||
|
||||
logger.debug(`[DEBUG] Final updatedData to save: ${JSON.stringify(updatedData, null, 2)}`)
|
||||
logger.debug(`[DEBUG] Updating Redis key: ${this.ACCOUNT_KEY_PREFIX}${accountId}`)
|
||||
|
||||
await client.hset(`${this.ACCOUNT_KEY_PREFIX}${accountId}`, updatedData)
|
||||
|
||||
logger.success(`📝 Updated Claude Console account: ${accountId}`)
|
||||
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to update Claude Console account:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to update Claude Console account:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🗑️ 删除账户
|
||||
async deleteAccount(accountId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const account = await this.getAccount(accountId);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const account = await this.getAccount(accountId)
|
||||
|
||||
if (!account) {
|
||||
throw new Error('Account not found');
|
||||
throw new Error('Account not found')
|
||||
}
|
||||
|
||||
|
||||
// 从Redis删除
|
||||
await client.del(`${this.ACCOUNT_KEY_PREFIX}${accountId}`);
|
||||
|
||||
await client.del(`${this.ACCOUNT_KEY_PREFIX}${accountId}`)
|
||||
|
||||
// 从共享账户集合中移除
|
||||
if (account.accountType === 'shared') {
|
||||
await client.srem(this.SHARED_ACCOUNTS_KEY, accountId);
|
||||
await client.srem(this.SHARED_ACCOUNTS_KEY, accountId)
|
||||
}
|
||||
|
||||
logger.success(`🗑️ Deleted Claude Console account: ${accountId}`);
|
||||
|
||||
return { success: true };
|
||||
|
||||
logger.success(`🗑️ Deleted Claude Console account: ${accountId}`)
|
||||
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to delete Claude Console account:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to delete Claude Console account:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 🚫 标记账号为限流状态
|
||||
async markAccountRateLimited(accountId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const account = await this.getAccount(accountId);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const account = await this.getAccount(accountId)
|
||||
|
||||
if (!account) {
|
||||
throw new Error('Account not found');
|
||||
throw new Error('Account not found')
|
||||
}
|
||||
|
||||
const updates = {
|
||||
rateLimitedAt: new Date().toISOString(),
|
||||
rateLimitStatus: 'limited'
|
||||
};
|
||||
}
|
||||
|
||||
await client.hset(
|
||||
`${this.ACCOUNT_KEY_PREFIX}${accountId}`,
|
||||
updates
|
||||
);
|
||||
await client.hset(`${this.ACCOUNT_KEY_PREFIX}${accountId}`, updates)
|
||||
|
||||
logger.warn(`🚫 Claude Console account marked as rate limited: ${account.name} (${accountId})`);
|
||||
return { success: true };
|
||||
logger.warn(
|
||||
`🚫 Claude Console account marked as rate limited: ${account.name} (${accountId})`
|
||||
)
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to mark Claude Console account as rate limited: ${accountId}`, error);
|
||||
throw error;
|
||||
logger.error(`❌ Failed to mark Claude Console account as rate limited: ${accountId}`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// ✅ 移除账号的限流状态
|
||||
async removeAccountRateLimit(accountId) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
await client.hdel(
|
||||
`${this.ACCOUNT_KEY_PREFIX}${accountId}`,
|
||||
'rateLimitedAt',
|
||||
'rateLimitStatus'
|
||||
);
|
||||
)
|
||||
|
||||
logger.success(`✅ Rate limit removed for Claude Console account: ${accountId}`);
|
||||
return { success: true };
|
||||
logger.success(`✅ Rate limit removed for Claude Console account: ${accountId}`)
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to remove rate limit for Claude Console account: ${accountId}`, error);
|
||||
throw error;
|
||||
logger.error(`❌ Failed to remove rate limit for Claude Console account: ${accountId}`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 检查账号是否处于限流状态
|
||||
async isAccountRateLimited(accountId) {
|
||||
try {
|
||||
const account = await this.getAccount(accountId);
|
||||
const account = await this.getAccount(accountId)
|
||||
if (!account) {
|
||||
return false;
|
||||
return false
|
||||
}
|
||||
|
||||
if (account.rateLimitStatus === 'limited' && account.rateLimitedAt) {
|
||||
const rateLimitedAt = new Date(account.rateLimitedAt);
|
||||
const now = new Date();
|
||||
const minutesSinceRateLimit = (now - rateLimitedAt) / (1000 * 60);
|
||||
const rateLimitedAt = new Date(account.rateLimitedAt)
|
||||
const now = new Date()
|
||||
const minutesSinceRateLimit = (now - rateLimitedAt) / (1000 * 60)
|
||||
|
||||
// 使用账户配置的限流时间
|
||||
const rateLimitDuration = account.rateLimitDuration || 60;
|
||||
|
||||
const rateLimitDuration = account.rateLimitDuration || 60
|
||||
|
||||
if (minutesSinceRateLimit >= rateLimitDuration) {
|
||||
await this.removeAccountRateLimit(accountId);
|
||||
return false;
|
||||
await this.removeAccountRateLimit(accountId)
|
||||
return false
|
||||
}
|
||||
|
||||
return true;
|
||||
return true
|
||||
}
|
||||
|
||||
return false;
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to check rate limit status for Claude Console account: ${accountId}`, error);
|
||||
return false;
|
||||
logger.error(
|
||||
`❌ Failed to check rate limit status for Claude Console account: ${accountId}`,
|
||||
error
|
||||
)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 🚫 标记账号为封锁状态(模型不支持等原因)
|
||||
async blockAccount(accountId, reason) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
const updates = {
|
||||
status: 'blocked',
|
||||
errorMessage: reason,
|
||||
blockedAt: new Date().toISOString()
|
||||
};
|
||||
}
|
||||
|
||||
await client.hset(
|
||||
`${this.ACCOUNT_KEY_PREFIX}${accountId}`,
|
||||
updates
|
||||
);
|
||||
await client.hset(`${this.ACCOUNT_KEY_PREFIX}${accountId}`, updates)
|
||||
|
||||
logger.warn(`🚫 Claude Console account blocked: ${accountId} - ${reason}`);
|
||||
return { success: true };
|
||||
logger.warn(`🚫 Claude Console account blocked: ${accountId} - ${reason}`)
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to block Claude Console account: ${accountId}`, error);
|
||||
throw error;
|
||||
logger.error(`❌ Failed to block Claude Console account: ${accountId}`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🌐 创建代理agent
|
||||
_createProxyAgent(proxyConfig) {
|
||||
if (!proxyConfig) {
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const proxy = typeof proxyConfig === 'string' ? JSON.parse(proxyConfig) : proxyConfig;
|
||||
|
||||
const proxy = typeof proxyConfig === 'string' ? JSON.parse(proxyConfig) : proxyConfig
|
||||
|
||||
if (proxy.type === 'socks5') {
|
||||
const auth = proxy.username && proxy.password ? `${proxy.username}:${proxy.password}@` : '';
|
||||
const socksUrl = `socks5://${auth}${proxy.host}:${proxy.port}`;
|
||||
return new SocksProxyAgent(socksUrl);
|
||||
const auth = proxy.username && proxy.password ? `${proxy.username}:${proxy.password}@` : ''
|
||||
const socksUrl = `socks5://${auth}${proxy.host}:${proxy.port}`
|
||||
return new SocksProxyAgent(socksUrl)
|
||||
} else if (proxy.type === 'http' || proxy.type === 'https') {
|
||||
const auth = proxy.username && proxy.password ? `${proxy.username}:${proxy.password}@` : '';
|
||||
const httpUrl = `${proxy.type}://${auth}${proxy.host}:${proxy.port}`;
|
||||
return new HttpsProxyAgent(httpUrl);
|
||||
const auth = proxy.username && proxy.password ? `${proxy.username}:${proxy.password}@` : ''
|
||||
const httpUrl = `${proxy.type}://${auth}${proxy.host}:${proxy.port}`
|
||||
return new HttpsProxyAgent(httpUrl)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('⚠️ Invalid proxy configuration:', error);
|
||||
logger.warn('⚠️ Invalid proxy configuration:', error)
|
||||
}
|
||||
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
// 🔐 加密敏感数据
|
||||
_encryptSensitiveData(data) {
|
||||
if (!data) return '';
|
||||
|
||||
if (!data) {
|
||||
return ''
|
||||
}
|
||||
|
||||
try {
|
||||
const key = this._generateEncryptionKey();
|
||||
const iv = crypto.randomBytes(16);
|
||||
|
||||
const cipher = crypto.createCipheriv(this.ENCRYPTION_ALGORITHM, key, iv);
|
||||
let encrypted = cipher.update(data, 'utf8', 'hex');
|
||||
encrypted += cipher.final('hex');
|
||||
|
||||
return iv.toString('hex') + ':' + encrypted;
|
||||
const key = this._generateEncryptionKey()
|
||||
const iv = crypto.randomBytes(16)
|
||||
|
||||
const cipher = crypto.createCipheriv(this.ENCRYPTION_ALGORITHM, key, iv)
|
||||
let encrypted = cipher.update(data, 'utf8', 'hex')
|
||||
encrypted += cipher.final('hex')
|
||||
|
||||
return `${iv.toString('hex')}:${encrypted}`
|
||||
} catch (error) {
|
||||
logger.error('❌ Encryption error:', error);
|
||||
return data;
|
||||
logger.error('❌ Encryption error:', error)
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
// 🔓 解密敏感数据
|
||||
_decryptSensitiveData(encryptedData) {
|
||||
if (!encryptedData) return '';
|
||||
|
||||
if (!encryptedData) {
|
||||
return ''
|
||||
}
|
||||
|
||||
try {
|
||||
if (encryptedData.includes(':')) {
|
||||
const parts = encryptedData.split(':');
|
||||
const parts = encryptedData.split(':')
|
||||
if (parts.length === 2) {
|
||||
const key = this._generateEncryptionKey();
|
||||
const iv = Buffer.from(parts[0], 'hex');
|
||||
const encrypted = parts[1];
|
||||
|
||||
const decipher = crypto.createDecipheriv(this.ENCRYPTION_ALGORITHM, key, iv);
|
||||
let decrypted = decipher.update(encrypted, 'hex', 'utf8');
|
||||
decrypted += decipher.final('utf8');
|
||||
return decrypted;
|
||||
const key = this._generateEncryptionKey()
|
||||
const iv = Buffer.from(parts[0], 'hex')
|
||||
const encrypted = parts[1]
|
||||
|
||||
const decipher = crypto.createDecipheriv(this.ENCRYPTION_ALGORITHM, key, iv)
|
||||
let decrypted = decipher.update(encrypted, 'hex', 'utf8')
|
||||
decrypted += decipher.final('utf8')
|
||||
return decrypted
|
||||
}
|
||||
}
|
||||
|
||||
return encryptedData;
|
||||
|
||||
return encryptedData
|
||||
} catch (error) {
|
||||
logger.error('❌ Decryption error:', error);
|
||||
return encryptedData;
|
||||
logger.error('❌ Decryption error:', error)
|
||||
return encryptedData
|
||||
}
|
||||
}
|
||||
|
||||
// 🔑 生成加密密钥
|
||||
_generateEncryptionKey() {
|
||||
return crypto.scryptSync(config.security.encryptionKey, this.ENCRYPTION_SALT, 32);
|
||||
return crypto.scryptSync(config.security.encryptionKey, this.ENCRYPTION_SALT, 32)
|
||||
}
|
||||
|
||||
// 🎭 掩码API URL
|
||||
_maskApiUrl(apiUrl) {
|
||||
if (!apiUrl) return '';
|
||||
|
||||
if (!apiUrl) {
|
||||
return ''
|
||||
}
|
||||
|
||||
try {
|
||||
const url = new URL(apiUrl);
|
||||
return `${url.protocol}//${url.hostname}/***`;
|
||||
const url = new URL(apiUrl)
|
||||
return `${url.protocol}//${url.hostname}/***`
|
||||
} catch {
|
||||
return '***';
|
||||
return '***'
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 获取限流信息
|
||||
_getRateLimitInfo(accountData) {
|
||||
if (accountData.rateLimitStatus === 'limited' && accountData.rateLimitedAt) {
|
||||
const rateLimitedAt = new Date(accountData.rateLimitedAt);
|
||||
const now = new Date();
|
||||
const minutesSinceRateLimit = Math.floor((now - rateLimitedAt) / (1000 * 60));
|
||||
const rateLimitDuration = parseInt(accountData.rateLimitDuration) || 60;
|
||||
const minutesRemaining = Math.max(0, rateLimitDuration - minutesSinceRateLimit);
|
||||
const rateLimitedAt = new Date(accountData.rateLimitedAt)
|
||||
const now = new Date()
|
||||
const minutesSinceRateLimit = Math.floor((now - rateLimitedAt) / (1000 * 60))
|
||||
const rateLimitDuration = parseInt(accountData.rateLimitDuration) || 60
|
||||
const minutesRemaining = Math.max(0, rateLimitDuration - minutesSinceRateLimit)
|
||||
|
||||
return {
|
||||
isRateLimited: minutesRemaining > 0,
|
||||
rateLimitedAt: accountData.rateLimitedAt,
|
||||
minutesSinceRateLimit,
|
||||
minutesRemaining
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -491,57 +513,57 @@ class ClaudeConsoleAccountService {
|
||||
rateLimitedAt: null,
|
||||
minutesSinceRateLimit: 0,
|
||||
minutesRemaining: 0
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 🔄 处理模型映射,确保向后兼容
|
||||
_processModelMapping(supportedModels) {
|
||||
// 如果是空值,返回空对象(支持所有模型)
|
||||
if (!supportedModels || (Array.isArray(supportedModels) && supportedModels.length === 0)) {
|
||||
return {};
|
||||
return {}
|
||||
}
|
||||
|
||||
// 如果已经是对象格式(新的映射表格式),直接返回
|
||||
if (typeof supportedModels === 'object' && !Array.isArray(supportedModels)) {
|
||||
return supportedModels;
|
||||
return supportedModels
|
||||
}
|
||||
|
||||
// 如果是数组格式(旧格式),转换为映射表
|
||||
if (Array.isArray(supportedModels)) {
|
||||
const mapping = {};
|
||||
supportedModels.forEach(model => {
|
||||
const mapping = {}
|
||||
supportedModels.forEach((model) => {
|
||||
if (model && typeof model === 'string') {
|
||||
mapping[model] = model; // 映射到自身
|
||||
mapping[model] = model // 映射到自身
|
||||
}
|
||||
});
|
||||
return mapping;
|
||||
})
|
||||
return mapping
|
||||
}
|
||||
|
||||
// 其他情况返回空对象
|
||||
return {};
|
||||
return {}
|
||||
}
|
||||
|
||||
// 🔍 检查模型是否支持(用于调度)
|
||||
isModelSupported(modelMapping, requestedModel) {
|
||||
// 如果映射表为空,支持所有模型
|
||||
if (!modelMapping || Object.keys(modelMapping).length === 0) {
|
||||
return true;
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查请求的模型是否在映射表的键中
|
||||
return Object.prototype.hasOwnProperty.call(modelMapping, requestedModel);
|
||||
return Object.prototype.hasOwnProperty.call(modelMapping, requestedModel)
|
||||
}
|
||||
|
||||
// 🔄 获取映射后的模型名称
|
||||
getMappedModel(modelMapping, requestedModel) {
|
||||
// 如果映射表为空,返回原模型
|
||||
if (!modelMapping || Object.keys(modelMapping).length === 0) {
|
||||
return requestedModel;
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// 返回映射后的模型,如果不存在则返回原模型
|
||||
return modelMapping[requestedModel] || requestedModel;
|
||||
return modelMapping[requestedModel] || requestedModel
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new ClaudeConsoleAccountService();
|
||||
module.exports = new ClaudeConsoleAccountService()
|
||||
|
||||
@@ -1,37 +1,54 @@
|
||||
const axios = require('axios');
|
||||
const claudeConsoleAccountService = require('./claudeConsoleAccountService');
|
||||
const logger = require('../utils/logger');
|
||||
const config = require('../../config/config');
|
||||
const axios = require('axios')
|
||||
const claudeConsoleAccountService = require('./claudeConsoleAccountService')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
|
||||
class ClaudeConsoleRelayService {
|
||||
constructor() {
|
||||
this.defaultUserAgent = 'claude-cli/1.0.69 (external, cli)';
|
||||
this.defaultUserAgent = 'claude-cli/1.0.69 (external, cli)'
|
||||
}
|
||||
|
||||
// 🚀 转发请求到Claude Console API
|
||||
async relayRequest(requestBody, apiKeyData, clientRequest, clientResponse, clientHeaders, accountId, options = {}) {
|
||||
let abortController = null;
|
||||
|
||||
async relayRequest(
|
||||
requestBody,
|
||||
apiKeyData,
|
||||
clientRequest,
|
||||
clientResponse,
|
||||
clientHeaders,
|
||||
accountId,
|
||||
options = {}
|
||||
) {
|
||||
let abortController = null
|
||||
|
||||
try {
|
||||
// 获取账户信息
|
||||
const account = await claudeConsoleAccountService.getAccount(accountId);
|
||||
const account = await claudeConsoleAccountService.getAccount(accountId)
|
||||
if (!account) {
|
||||
throw new Error('Claude Console Claude account not found');
|
||||
throw new Error('Claude Console Claude account not found')
|
||||
}
|
||||
|
||||
logger.info(`📤 Processing Claude Console API request for key: ${apiKeyData.name || apiKeyData.id}, account: ${account.name} (${accountId})`);
|
||||
logger.debug(`🌐 Account API URL: ${account.apiUrl}`);
|
||||
logger.debug(`🔍 Account supportedModels: ${JSON.stringify(account.supportedModels)}`);
|
||||
logger.debug(`🔑 Account has apiKey: ${!!account.apiKey}`);
|
||||
logger.debug(`📝 Request model: ${requestBody.model}`);
|
||||
logger.info(
|
||||
`📤 Processing Claude Console API request for key: ${apiKeyData.name || apiKeyData.id}, account: ${account.name} (${accountId})`
|
||||
)
|
||||
logger.debug(`🌐 Account API URL: ${account.apiUrl}`)
|
||||
logger.debug(`🔍 Account supportedModels: ${JSON.stringify(account.supportedModels)}`)
|
||||
logger.debug(`🔑 Account has apiKey: ${!!account.apiKey}`)
|
||||
logger.debug(`📝 Request model: ${requestBody.model}`)
|
||||
|
||||
// 处理模型映射
|
||||
let mappedModel = requestBody.model;
|
||||
if (account.supportedModels && typeof account.supportedModels === 'object' && !Array.isArray(account.supportedModels)) {
|
||||
const newModel = claudeConsoleAccountService.getMappedModel(account.supportedModels, requestBody.model);
|
||||
let mappedModel = requestBody.model
|
||||
if (
|
||||
account.supportedModels &&
|
||||
typeof account.supportedModels === 'object' &&
|
||||
!Array.isArray(account.supportedModels)
|
||||
) {
|
||||
const newModel = claudeConsoleAccountService.getMappedModel(
|
||||
account.supportedModels,
|
||||
requestBody.model
|
||||
)
|
||||
if (newModel !== requestBody.model) {
|
||||
logger.info(`🔄 Mapping model from ${requestBody.model} to ${newModel}`);
|
||||
mappedModel = newModel;
|
||||
logger.info(`🔄 Mapping model from ${requestBody.model} to ${newModel}`)
|
||||
mappedModel = newModel
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,52 +56,51 @@ class ClaudeConsoleRelayService {
|
||||
const modifiedRequestBody = {
|
||||
...requestBody,
|
||||
model: mappedModel
|
||||
};
|
||||
}
|
||||
|
||||
// 模型兼容性检查已经在调度器中完成,这里不需要再检查
|
||||
|
||||
// 创建代理agent
|
||||
const proxyAgent = claudeConsoleAccountService._createProxyAgent(account.proxy);
|
||||
const proxyAgent = claudeConsoleAccountService._createProxyAgent(account.proxy)
|
||||
|
||||
// 创建AbortController用于取消请求
|
||||
abortController = new AbortController();
|
||||
abortController = new AbortController()
|
||||
|
||||
// 设置客户端断开监听器
|
||||
const handleClientDisconnect = () => {
|
||||
logger.info('🔌 Client disconnected, aborting Claude Console Claude request');
|
||||
logger.info('🔌 Client disconnected, aborting Claude Console Claude request')
|
||||
if (abortController && !abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
abortController.abort()
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 监听客户端断开事件
|
||||
if (clientRequest) {
|
||||
clientRequest.once('close', handleClientDisconnect);
|
||||
clientRequest.once('close', handleClientDisconnect)
|
||||
}
|
||||
if (clientResponse) {
|
||||
clientResponse.once('close', handleClientDisconnect);
|
||||
clientResponse.once('close', handleClientDisconnect)
|
||||
}
|
||||
|
||||
// 构建完整的API URL
|
||||
const cleanUrl = account.apiUrl.replace(/\/$/, ''); // 移除末尾斜杠
|
||||
const apiEndpoint = cleanUrl.endsWith('/v1/messages')
|
||||
? cleanUrl
|
||||
: `${cleanUrl}/v1/messages`;
|
||||
|
||||
logger.debug(`🎯 Final API endpoint: ${apiEndpoint}`);
|
||||
logger.debug(`[DEBUG] Options passed to relayRequest: ${JSON.stringify(options)}`);
|
||||
logger.debug(`[DEBUG] Client headers received: ${JSON.stringify(clientHeaders)}`);
|
||||
|
||||
const cleanUrl = account.apiUrl.replace(/\/$/, '') // 移除末尾斜杠
|
||||
const apiEndpoint = cleanUrl.endsWith('/v1/messages') ? cleanUrl : `${cleanUrl}/v1/messages`
|
||||
|
||||
logger.debug(`🎯 Final API endpoint: ${apiEndpoint}`)
|
||||
logger.debug(`[DEBUG] Options passed to relayRequest: ${JSON.stringify(options)}`)
|
||||
logger.debug(`[DEBUG] Client headers received: ${JSON.stringify(clientHeaders)}`)
|
||||
|
||||
// 过滤客户端请求头
|
||||
const filteredHeaders = this._filterClientHeaders(clientHeaders);
|
||||
logger.debug(`[DEBUG] Filtered client headers: ${JSON.stringify(filteredHeaders)}`);
|
||||
|
||||
const filteredHeaders = this._filterClientHeaders(clientHeaders)
|
||||
logger.debug(`[DEBUG] Filtered client headers: ${JSON.stringify(filteredHeaders)}`)
|
||||
|
||||
// 决定使用的 User-Agent:优先使用账户自定义的,否则透传客户端的,最后才使用默认值
|
||||
const userAgent = account.userAgent ||
|
||||
clientHeaders?.['user-agent'] ||
|
||||
clientHeaders?.['User-Agent'] ||
|
||||
this.defaultUserAgent;
|
||||
|
||||
const userAgent =
|
||||
account.userAgent ||
|
||||
clientHeaders?.['user-agent'] ||
|
||||
clientHeaders?.['User-Agent'] ||
|
||||
this.defaultUserAgent
|
||||
|
||||
// 准备请求配置
|
||||
const requestConfig = {
|
||||
method: 'POST',
|
||||
@@ -100,103 +116,123 @@ class ClaudeConsoleRelayService {
|
||||
timeout: config.proxy.timeout || 60000,
|
||||
signal: abortController.signal,
|
||||
validateStatus: () => true // 接受所有状态码
|
||||
};
|
||||
}
|
||||
|
||||
// 根据 API Key 格式选择认证方式
|
||||
if (account.apiKey && account.apiKey.startsWith('sk-ant-')) {
|
||||
// Anthropic 官方 API Key 使用 x-api-key
|
||||
requestConfig.headers['x-api-key'] = account.apiKey;
|
||||
logger.debug('[DEBUG] Using x-api-key authentication for sk-ant-* API key');
|
||||
requestConfig.headers['x-api-key'] = account.apiKey
|
||||
logger.debug('[DEBUG] Using x-api-key authentication for sk-ant-* API key')
|
||||
} else {
|
||||
// 其他 API Key 使用 Authorization Bearer
|
||||
requestConfig.headers['Authorization'] = `Bearer ${account.apiKey}`;
|
||||
logger.debug('[DEBUG] Using Authorization Bearer authentication');
|
||||
requestConfig.headers['Authorization'] = `Bearer ${account.apiKey}`
|
||||
logger.debug('[DEBUG] Using Authorization Bearer authentication')
|
||||
}
|
||||
|
||||
logger.debug(`[DEBUG] Initial headers before beta: ${JSON.stringify(requestConfig.headers, null, 2)}`);
|
||||
|
||||
|
||||
logger.debug(
|
||||
`[DEBUG] Initial headers before beta: ${JSON.stringify(requestConfig.headers, null, 2)}`
|
||||
)
|
||||
|
||||
// 添加beta header如果需要
|
||||
if (options.betaHeader) {
|
||||
logger.debug(`[DEBUG] Adding beta header: ${options.betaHeader}`);
|
||||
requestConfig.headers['anthropic-beta'] = options.betaHeader;
|
||||
logger.debug(`[DEBUG] Adding beta header: ${options.betaHeader}`)
|
||||
requestConfig.headers['anthropic-beta'] = options.betaHeader
|
||||
} else {
|
||||
logger.debug('[DEBUG] No beta header to add');
|
||||
logger.debug('[DEBUG] No beta header to add')
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
logger.debug('📤 Sending request to Claude Console API with headers:', JSON.stringify(requestConfig.headers, null, 2));
|
||||
const response = await axios(requestConfig);
|
||||
logger.debug(
|
||||
'📤 Sending request to Claude Console API with headers:',
|
||||
JSON.stringify(requestConfig.headers, null, 2)
|
||||
)
|
||||
const response = await axios(requestConfig)
|
||||
|
||||
// 移除监听器(请求成功完成)
|
||||
if (clientRequest) {
|
||||
clientRequest.removeListener('close', handleClientDisconnect);
|
||||
clientRequest.removeListener('close', handleClientDisconnect)
|
||||
}
|
||||
if (clientResponse) {
|
||||
clientResponse.removeListener('close', handleClientDisconnect);
|
||||
clientResponse.removeListener('close', handleClientDisconnect)
|
||||
}
|
||||
|
||||
logger.debug(`🔗 Claude Console API response: ${response.status}`);
|
||||
logger.debug(`[DEBUG] Response headers: ${JSON.stringify(response.headers)}`);
|
||||
logger.debug(`[DEBUG] Response data type: ${typeof response.data}`);
|
||||
logger.debug(`[DEBUG] Response data length: ${response.data ? (typeof response.data === 'string' ? response.data.length : JSON.stringify(response.data).length) : 0}`);
|
||||
logger.debug(`[DEBUG] Response data preview: ${typeof response.data === 'string' ? response.data.substring(0, 200) : JSON.stringify(response.data).substring(0, 200)}`);
|
||||
logger.debug(`🔗 Claude Console API response: ${response.status}`)
|
||||
logger.debug(`[DEBUG] Response headers: ${JSON.stringify(response.headers)}`)
|
||||
logger.debug(`[DEBUG] Response data type: ${typeof response.data}`)
|
||||
logger.debug(
|
||||
`[DEBUG] Response data length: ${response.data ? (typeof response.data === 'string' ? response.data.length : JSON.stringify(response.data).length) : 0}`
|
||||
)
|
||||
logger.debug(
|
||||
`[DEBUG] Response data preview: ${typeof response.data === 'string' ? response.data.substring(0, 200) : JSON.stringify(response.data).substring(0, 200)}`
|
||||
)
|
||||
|
||||
// 检查是否为限流错误
|
||||
if (response.status === 429) {
|
||||
logger.warn(`🚫 Rate limit detected for Claude Console account ${accountId}`);
|
||||
await claudeConsoleAccountService.markAccountRateLimited(accountId);
|
||||
logger.warn(`🚫 Rate limit detected for Claude Console account ${accountId}`)
|
||||
await claudeConsoleAccountService.markAccountRateLimited(accountId)
|
||||
} else if (response.status === 200 || response.status === 201) {
|
||||
// 如果请求成功,检查并移除限流状态
|
||||
const isRateLimited = await claudeConsoleAccountService.isAccountRateLimited(accountId);
|
||||
const isRateLimited = await claudeConsoleAccountService.isAccountRateLimited(accountId)
|
||||
if (isRateLimited) {
|
||||
await claudeConsoleAccountService.removeAccountRateLimit(accountId);
|
||||
await claudeConsoleAccountService.removeAccountRateLimit(accountId)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新最后使用时间
|
||||
await this._updateLastUsedTime(accountId);
|
||||
await this._updateLastUsedTime(accountId)
|
||||
|
||||
const responseBody = typeof response.data === 'string' ? response.data : JSON.stringify(response.data);
|
||||
logger.debug(`[DEBUG] Final response body to return: ${responseBody}`);
|
||||
const responseBody =
|
||||
typeof response.data === 'string' ? response.data : JSON.stringify(response.data)
|
||||
logger.debug(`[DEBUG] Final response body to return: ${responseBody}`)
|
||||
|
||||
return {
|
||||
statusCode: response.status,
|
||||
headers: response.headers,
|
||||
body: responseBody,
|
||||
accountId
|
||||
};
|
||||
|
||||
}
|
||||
} catch (error) {
|
||||
// 处理特定错误
|
||||
if (error.name === 'AbortError' || error.code === 'ECONNABORTED') {
|
||||
logger.info('Request aborted due to client disconnect');
|
||||
throw new Error('Client disconnected');
|
||||
logger.info('Request aborted due to client disconnect')
|
||||
throw new Error('Client disconnected')
|
||||
}
|
||||
|
||||
logger.error('❌ Claude Console Claude relay request failed:', error.message);
|
||||
|
||||
logger.error('❌ Claude Console Claude relay request failed:', error.message)
|
||||
|
||||
// 不再因为模型不支持而block账号
|
||||
|
||||
throw error;
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🌊 处理流式响应
|
||||
async relayStreamRequestWithUsageCapture(requestBody, apiKeyData, responseStream, clientHeaders, usageCallback, accountId, streamTransformer = null, options = {}) {
|
||||
async relayStreamRequestWithUsageCapture(
|
||||
requestBody,
|
||||
apiKeyData,
|
||||
responseStream,
|
||||
clientHeaders,
|
||||
usageCallback,
|
||||
accountId,
|
||||
streamTransformer = null,
|
||||
options = {}
|
||||
) {
|
||||
try {
|
||||
// 获取账户信息
|
||||
const account = await claudeConsoleAccountService.getAccount(accountId);
|
||||
const account = await claudeConsoleAccountService.getAccount(accountId)
|
||||
if (!account) {
|
||||
throw new Error('Claude Console Claude account not found');
|
||||
throw new Error('Claude Console Claude account not found')
|
||||
}
|
||||
|
||||
logger.info(`📡 Processing streaming Claude Console API request for key: ${apiKeyData.name || apiKeyData.id}, account: ${account.name} (${accountId})`);
|
||||
logger.debug(`🌐 Account API URL: ${account.apiUrl}`);
|
||||
logger.info(
|
||||
`📡 Processing streaming Claude Console API request for key: ${apiKeyData.name || apiKeyData.id}, account: ${account.name} (${accountId})`
|
||||
)
|
||||
logger.debug(`🌐 Account API URL: ${account.apiUrl}`)
|
||||
|
||||
// 模型兼容性检查已经在调度器中完成,这里不需要再检查
|
||||
|
||||
// 创建代理agent
|
||||
const proxyAgent = claudeConsoleAccountService._createProxyAgent(account.proxy);
|
||||
const proxyAgent = claudeConsoleAccountService._createProxyAgent(account.proxy)
|
||||
|
||||
// 发送流式请求
|
||||
await this._makeClaudeConsoleStreamRequest(
|
||||
@@ -209,40 +245,48 @@ class ClaudeConsoleRelayService {
|
||||
usageCallback,
|
||||
streamTransformer,
|
||||
options
|
||||
);
|
||||
)
|
||||
|
||||
// 更新最后使用时间
|
||||
await this._updateLastUsedTime(accountId);
|
||||
|
||||
await this._updateLastUsedTime(accountId)
|
||||
} catch (error) {
|
||||
logger.error('❌ Claude Console Claude stream relay failed:', error);
|
||||
throw error;
|
||||
logger.error('❌ Claude Console Claude stream relay failed:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🌊 发送流式请求到Claude Console API
|
||||
async _makeClaudeConsoleStreamRequest(body, account, proxyAgent, clientHeaders, responseStream, accountId, usageCallback, streamTransformer = null, requestOptions = {}) {
|
||||
async _makeClaudeConsoleStreamRequest(
|
||||
body,
|
||||
account,
|
||||
proxyAgent,
|
||||
clientHeaders,
|
||||
responseStream,
|
||||
accountId,
|
||||
usageCallback,
|
||||
streamTransformer = null,
|
||||
requestOptions = {}
|
||||
) {
|
||||
return new Promise((resolve, reject) => {
|
||||
let aborted = false;
|
||||
let aborted = false
|
||||
|
||||
// 构建完整的API URL
|
||||
const cleanUrl = account.apiUrl.replace(/\/$/, ''); // 移除末尾斜杠
|
||||
const apiEndpoint = cleanUrl.endsWith('/v1/messages')
|
||||
? cleanUrl
|
||||
: `${cleanUrl}/v1/messages`;
|
||||
|
||||
logger.debug(`🎯 Final API endpoint for stream: ${apiEndpoint}`);
|
||||
const cleanUrl = account.apiUrl.replace(/\/$/, '') // 移除末尾斜杠
|
||||
const apiEndpoint = cleanUrl.endsWith('/v1/messages') ? cleanUrl : `${cleanUrl}/v1/messages`
|
||||
|
||||
logger.debug(`🎯 Final API endpoint for stream: ${apiEndpoint}`)
|
||||
|
||||
// 过滤客户端请求头
|
||||
const filteredHeaders = this._filterClientHeaders(clientHeaders);
|
||||
logger.debug(`[DEBUG] Filtered client headers: ${JSON.stringify(filteredHeaders)}`);
|
||||
|
||||
const filteredHeaders = this._filterClientHeaders(clientHeaders)
|
||||
logger.debug(`[DEBUG] Filtered client headers: ${JSON.stringify(filteredHeaders)}`)
|
||||
|
||||
// 决定使用的 User-Agent:优先使用账户自定义的,否则透传客户端的,最后才使用默认值
|
||||
const userAgent = account.userAgent ||
|
||||
clientHeaders?.['user-agent'] ||
|
||||
clientHeaders?.['User-Agent'] ||
|
||||
this.defaultUserAgent;
|
||||
|
||||
const userAgent =
|
||||
account.userAgent ||
|
||||
clientHeaders?.['user-agent'] ||
|
||||
clientHeaders?.['User-Agent'] ||
|
||||
this.defaultUserAgent
|
||||
|
||||
// 准备请求配置
|
||||
const requestConfig = {
|
||||
method: 'POST',
|
||||
@@ -258,237 +302,254 @@ class ClaudeConsoleRelayService {
|
||||
timeout: config.proxy.timeout || 60000,
|
||||
responseType: 'stream',
|
||||
validateStatus: () => true // 接受所有状态码
|
||||
};
|
||||
}
|
||||
|
||||
// 根据 API Key 格式选择认证方式
|
||||
if (account.apiKey && account.apiKey.startsWith('sk-ant-')) {
|
||||
// Anthropic 官方 API Key 使用 x-api-key
|
||||
requestConfig.headers['x-api-key'] = account.apiKey;
|
||||
logger.debug('[DEBUG] Using x-api-key authentication for sk-ant-* API key');
|
||||
requestConfig.headers['x-api-key'] = account.apiKey
|
||||
logger.debug('[DEBUG] Using x-api-key authentication for sk-ant-* API key')
|
||||
} else {
|
||||
// 其他 API Key 使用 Authorization Bearer
|
||||
requestConfig.headers['Authorization'] = `Bearer ${account.apiKey}`;
|
||||
logger.debug('[DEBUG] Using Authorization Bearer authentication');
|
||||
requestConfig.headers['Authorization'] = `Bearer ${account.apiKey}`
|
||||
logger.debug('[DEBUG] Using Authorization Bearer authentication')
|
||||
}
|
||||
|
||||
|
||||
// 添加beta header如果需要
|
||||
if (requestOptions.betaHeader) {
|
||||
requestConfig.headers['anthropic-beta'] = requestOptions.betaHeader;
|
||||
requestConfig.headers['anthropic-beta'] = requestOptions.betaHeader
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
const request = axios(requestConfig);
|
||||
const request = axios(requestConfig)
|
||||
|
||||
request.then(response => {
|
||||
logger.debug(`🌊 Claude Console Claude stream response status: ${response.status}`);
|
||||
request
|
||||
.then((response) => {
|
||||
logger.debug(`🌊 Claude Console Claude stream response status: ${response.status}`)
|
||||
|
||||
// 错误响应处理
|
||||
if (response.status !== 200) {
|
||||
logger.error(`❌ Claude Console API returned error status: ${response.status}`);
|
||||
|
||||
if (response.status === 429) {
|
||||
claudeConsoleAccountService.markAccountRateLimited(accountId);
|
||||
// 错误响应处理
|
||||
if (response.status !== 200) {
|
||||
logger.error(`❌ Claude Console API returned error status: ${response.status}`)
|
||||
|
||||
if (response.status === 429) {
|
||||
claudeConsoleAccountService.markAccountRateLimited(accountId)
|
||||
}
|
||||
|
||||
// 设置错误响应的状态码和响应头
|
||||
if (!responseStream.headersSent) {
|
||||
const errorHeaders = {
|
||||
'Content-Type': response.headers['content-type'] || 'application/json',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive'
|
||||
}
|
||||
// 避免 Transfer-Encoding 冲突,让 Express 自动处理
|
||||
delete errorHeaders['Transfer-Encoding']
|
||||
delete errorHeaders['Content-Length']
|
||||
responseStream.writeHead(response.status, errorHeaders)
|
||||
}
|
||||
|
||||
// 直接透传错误数据,不进行包装
|
||||
response.data.on('data', (chunk) => {
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.write(chunk)
|
||||
}
|
||||
})
|
||||
|
||||
response.data.on('end', () => {
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.end()
|
||||
}
|
||||
resolve() // 不抛出异常,正常完成流处理
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置错误响应的状态码和响应头
|
||||
// 成功响应,检查并移除限流状态
|
||||
claudeConsoleAccountService.isAccountRateLimited(accountId).then((isRateLimited) => {
|
||||
if (isRateLimited) {
|
||||
claudeConsoleAccountService.removeAccountRateLimit(accountId)
|
||||
}
|
||||
})
|
||||
|
||||
// 设置响应头
|
||||
if (!responseStream.headersSent) {
|
||||
const errorHeaders = {
|
||||
'Content-Type': response.headers['content-type'] || 'application/json',
|
||||
responseStream.writeHead(200, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive'
|
||||
};
|
||||
// 避免 Transfer-Encoding 冲突,让 Express 自动处理
|
||||
delete errorHeaders['Transfer-Encoding'];
|
||||
delete errorHeaders['Content-Length'];
|
||||
responseStream.writeHead(response.status, errorHeaders);
|
||||
Connection: 'keep-alive',
|
||||
'X-Accel-Buffering': 'no'
|
||||
})
|
||||
}
|
||||
|
||||
// 直接透传错误数据,不进行包装
|
||||
response.data.on('data', chunk => {
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.write(chunk);
|
||||
}
|
||||
});
|
||||
let buffer = ''
|
||||
let finalUsageReported = false
|
||||
const collectedUsageData = {}
|
||||
|
||||
response.data.on('end', () => {
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.end();
|
||||
}
|
||||
resolve(); // 不抛出异常,正常完成流处理
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// 成功响应,检查并移除限流状态
|
||||
claudeConsoleAccountService.isAccountRateLimited(accountId).then(isRateLimited => {
|
||||
if (isRateLimited) {
|
||||
claudeConsoleAccountService.removeAccountRateLimit(accountId);
|
||||
}
|
||||
});
|
||||
|
||||
// 设置响应头
|
||||
if (!responseStream.headersSent) {
|
||||
responseStream.writeHead(200, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
'X-Accel-Buffering': 'no'
|
||||
});
|
||||
}
|
||||
|
||||
let buffer = '';
|
||||
let finalUsageReported = false;
|
||||
let collectedUsageData = {};
|
||||
|
||||
// 处理流数据
|
||||
response.data.on('data', chunk => {
|
||||
try {
|
||||
if (aborted) return;
|
||||
|
||||
const chunkStr = chunk.toString();
|
||||
buffer += chunkStr;
|
||||
|
||||
// 处理完整的SSE行
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || '';
|
||||
|
||||
// 转发数据并解析usage
|
||||
if (lines.length > 0 && !responseStream.destroyed) {
|
||||
const linesToForward = lines.join('\n') + (lines.length > 0 ? '\n' : '');
|
||||
|
||||
// 应用流转换器如果有
|
||||
if (streamTransformer) {
|
||||
const transformed = streamTransformer(linesToForward);
|
||||
if (transformed) {
|
||||
responseStream.write(transformed);
|
||||
}
|
||||
} else {
|
||||
responseStream.write(linesToForward);
|
||||
// 处理流数据
|
||||
response.data.on('data', (chunk) => {
|
||||
try {
|
||||
if (aborted) {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析SSE数据寻找usage信息
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ') && line.length > 6) {
|
||||
try {
|
||||
const jsonStr = line.slice(6);
|
||||
const data = JSON.parse(jsonStr);
|
||||
|
||||
// 收集usage数据
|
||||
if (data.type === 'message_start' && data.message && data.message.usage) {
|
||||
collectedUsageData.input_tokens = data.message.usage.input_tokens || 0;
|
||||
collectedUsageData.cache_creation_input_tokens = data.message.usage.cache_creation_input_tokens || 0;
|
||||
collectedUsageData.cache_read_input_tokens = data.message.usage.cache_read_input_tokens || 0;
|
||||
collectedUsageData.model = data.message.model;
|
||||
}
|
||||
|
||||
if (data.type === 'message_delta' && data.usage && data.usage.output_tokens !== undefined) {
|
||||
collectedUsageData.output_tokens = data.usage.output_tokens || 0;
|
||||
|
||||
if (collectedUsageData.input_tokens !== undefined && !finalUsageReported) {
|
||||
usageCallback({ ...collectedUsageData, accountId });
|
||||
finalUsageReported = true;
|
||||
}
|
||||
}
|
||||
const chunkStr = chunk.toString()
|
||||
buffer += chunkStr
|
||||
|
||||
// 不再因为模型不支持而block账号
|
||||
} catch (e) {
|
||||
// 忽略解析错误
|
||||
// 处理完整的SSE行
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
// 转发数据并解析usage
|
||||
if (lines.length > 0 && !responseStream.destroyed) {
|
||||
const linesToForward = lines.join('\n') + (lines.length > 0 ? '\n' : '')
|
||||
|
||||
// 应用流转换器如果有
|
||||
if (streamTransformer) {
|
||||
const transformed = streamTransformer(linesToForward)
|
||||
if (transformed) {
|
||||
responseStream.write(transformed)
|
||||
}
|
||||
} else {
|
||||
responseStream.write(linesToForward)
|
||||
}
|
||||
|
||||
// 解析SSE数据寻找usage信息
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ') && line.length > 6) {
|
||||
try {
|
||||
const jsonStr = line.slice(6)
|
||||
const data = JSON.parse(jsonStr)
|
||||
|
||||
// 收集usage数据
|
||||
if (data.type === 'message_start' && data.message && data.message.usage) {
|
||||
collectedUsageData.input_tokens = data.message.usage.input_tokens || 0
|
||||
collectedUsageData.cache_creation_input_tokens =
|
||||
data.message.usage.cache_creation_input_tokens || 0
|
||||
collectedUsageData.cache_read_input_tokens =
|
||||
data.message.usage.cache_read_input_tokens || 0
|
||||
collectedUsageData.model = data.message.model
|
||||
}
|
||||
|
||||
if (
|
||||
data.type === 'message_delta' &&
|
||||
data.usage &&
|
||||
data.usage.output_tokens !== undefined
|
||||
) {
|
||||
collectedUsageData.output_tokens = data.usage.output_tokens || 0
|
||||
|
||||
if (collectedUsageData.input_tokens !== undefined && !finalUsageReported) {
|
||||
usageCallback({ ...collectedUsageData, accountId })
|
||||
finalUsageReported = true
|
||||
}
|
||||
}
|
||||
|
||||
// 不再因为模型不支持而block账号
|
||||
} catch (e) {
|
||||
// 忽略解析错误
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Error processing Claude Console stream data:', error);
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.write('event: error\n');
|
||||
responseStream.write(`data: ${JSON.stringify({
|
||||
error: 'Stream processing error',
|
||||
message: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
})}\n\n`);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
response.data.on('end', () => {
|
||||
try {
|
||||
// 处理缓冲区中剩余的数据
|
||||
if (buffer.trim() && !responseStream.destroyed) {
|
||||
if (streamTransformer) {
|
||||
const transformed = streamTransformer(buffer);
|
||||
if (transformed) {
|
||||
responseStream.write(transformed);
|
||||
}
|
||||
} else {
|
||||
responseStream.write(buffer);
|
||||
} catch (error) {
|
||||
logger.error('❌ Error processing Claude Console stream data:', error)
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.write('event: error\n')
|
||||
responseStream.write(
|
||||
`data: ${JSON.stringify({
|
||||
error: 'Stream processing error',
|
||||
message: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
})}\n\n`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 确保流正确结束
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.end();
|
||||
})
|
||||
|
||||
response.data.on('end', () => {
|
||||
try {
|
||||
// 处理缓冲区中剩余的数据
|
||||
if (buffer.trim() && !responseStream.destroyed) {
|
||||
if (streamTransformer) {
|
||||
const transformed = streamTransformer(buffer)
|
||||
if (transformed) {
|
||||
responseStream.write(transformed)
|
||||
}
|
||||
} else {
|
||||
responseStream.write(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
// 确保流正确结束
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.end()
|
||||
}
|
||||
|
||||
logger.debug('🌊 Claude Console Claude stream response completed')
|
||||
resolve()
|
||||
} catch (error) {
|
||||
logger.error('❌ Error processing stream end:', error)
|
||||
reject(error)
|
||||
}
|
||||
})
|
||||
|
||||
logger.debug('🌊 Claude Console Claude stream response completed');
|
||||
resolve();
|
||||
} catch (error) {
|
||||
logger.error('❌ Error processing stream end:', error);
|
||||
reject(error);
|
||||
response.data.on('error', (error) => {
|
||||
logger.error('❌ Claude Console stream error:', error)
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.write('event: error\n')
|
||||
responseStream.write(
|
||||
`data: ${JSON.stringify({
|
||||
error: 'Stream error',
|
||||
message: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
})}\n\n`
|
||||
)
|
||||
responseStream.end()
|
||||
}
|
||||
reject(error)
|
||||
})
|
||||
})
|
||||
.catch((error) => {
|
||||
if (aborted) {
|
||||
return
|
||||
}
|
||||
|
||||
logger.error('❌ Claude Console Claude stream request error:', error.message)
|
||||
|
||||
// 检查是否是429错误
|
||||
if (error.response && error.response.status === 429) {
|
||||
claudeConsoleAccountService.markAccountRateLimited(accountId)
|
||||
}
|
||||
|
||||
// 发送错误响应
|
||||
if (!responseStream.headersSent) {
|
||||
responseStream.writeHead(error.response?.status || 500, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive'
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
response.data.on('error', error => {
|
||||
logger.error('❌ Claude Console stream error:', error);
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.write('event: error\n');
|
||||
responseStream.write(`data: ${JSON.stringify({
|
||||
error: 'Stream error',
|
||||
message: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
})}\n\n`);
|
||||
responseStream.end();
|
||||
responseStream.write('event: error\n')
|
||||
responseStream.write(
|
||||
`data: ${JSON.stringify({
|
||||
error: error.message,
|
||||
code: error.code,
|
||||
timestamp: new Date().toISOString()
|
||||
})}\n\n`
|
||||
)
|
||||
responseStream.end()
|
||||
}
|
||||
reject(error);
|
||||
});
|
||||
|
||||
}).catch(error => {
|
||||
if (aborted) return;
|
||||
|
||||
logger.error('❌ Claude Console Claude stream request error:', error.message);
|
||||
|
||||
// 检查是否是429错误
|
||||
if (error.response && error.response.status === 429) {
|
||||
claudeConsoleAccountService.markAccountRateLimited(accountId);
|
||||
}
|
||||
|
||||
// 发送错误响应
|
||||
if (!responseStream.headersSent) {
|
||||
responseStream.writeHead(error.response?.status || 500, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive'
|
||||
});
|
||||
}
|
||||
|
||||
if (!responseStream.destroyed) {
|
||||
responseStream.write('event: error\n');
|
||||
responseStream.write(`data: ${JSON.stringify({
|
||||
error: error.message,
|
||||
code: error.code,
|
||||
timestamp: new Date().toISOString()
|
||||
})}\n\n`);
|
||||
responseStream.end();
|
||||
}
|
||||
|
||||
reject(error);
|
||||
});
|
||||
reject(error)
|
||||
})
|
||||
|
||||
// 处理客户端断开连接
|
||||
responseStream.on('close', () => {
|
||||
logger.debug('🔌 Client disconnected, cleaning up Claude Console stream');
|
||||
aborted = true;
|
||||
});
|
||||
});
|
||||
logger.debug('🔌 Client disconnected, cleaning up Claude Console stream')
|
||||
aborted = true
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 🔧 过滤客户端请求头
|
||||
@@ -505,55 +566,58 @@ class ClaudeConsoleRelayService {
|
||||
'content-encoding',
|
||||
'transfer-encoding',
|
||||
'anthropic-version'
|
||||
];
|
||||
|
||||
const filteredHeaders = {};
|
||||
|
||||
Object.keys(clientHeaders || {}).forEach(key => {
|
||||
const lowerKey = key.toLowerCase();
|
||||
]
|
||||
|
||||
const filteredHeaders = {}
|
||||
|
||||
Object.keys(clientHeaders || {}).forEach((key) => {
|
||||
const lowerKey = key.toLowerCase()
|
||||
if (!sensitiveHeaders.includes(lowerKey)) {
|
||||
filteredHeaders[key] = clientHeaders[key];
|
||||
filteredHeaders[key] = clientHeaders[key]
|
||||
}
|
||||
});
|
||||
|
||||
return filteredHeaders;
|
||||
})
|
||||
|
||||
return filteredHeaders
|
||||
}
|
||||
|
||||
// 🕐 更新最后使用时间
|
||||
async _updateLastUsedTime(accountId) {
|
||||
try {
|
||||
const client = require('../models/redis').getClientSafe();
|
||||
const client = require('../models/redis').getClientSafe()
|
||||
await client.hset(
|
||||
`claude_console_account:${accountId}`,
|
||||
'lastUsedAt',
|
||||
new Date().toISOString()
|
||||
);
|
||||
)
|
||||
} catch (error) {
|
||||
logger.warn(`⚠️ Failed to update last used time for Claude Console account ${accountId}:`, error.message);
|
||||
logger.warn(
|
||||
`⚠️ Failed to update last used time for Claude Console account ${accountId}:`,
|
||||
error.message
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 🎯 健康检查
|
||||
async healthCheck() {
|
||||
try {
|
||||
const accounts = await claudeConsoleAccountService.getAllAccounts();
|
||||
const activeAccounts = accounts.filter(acc => acc.isActive && acc.status === 'active');
|
||||
|
||||
const accounts = await claudeConsoleAccountService.getAllAccounts()
|
||||
const activeAccounts = accounts.filter((acc) => acc.isActive && acc.status === 'active')
|
||||
|
||||
return {
|
||||
healthy: activeAccounts.length > 0,
|
||||
activeAccounts: activeAccounts.length,
|
||||
totalAccounts: accounts.length,
|
||||
timestamp: new Date().toISOString()
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Claude Console Claude health check failed:', error);
|
||||
logger.error('❌ Claude Console Claude health check failed:', error)
|
||||
return {
|
||||
healthy: false,
|
||||
error: error.message,
|
||||
timestamp: new Date().toISOString()
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new ClaudeConsoleRelayService();
|
||||
module.exports = new ClaudeConsoleRelayService()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
||||
const redis = require('../models/redis');
|
||||
const apiKeyService = require('./apiKeyService');
|
||||
const CostCalculator = require('../utils/costCalculator');
|
||||
const logger = require('../utils/logger');
|
||||
const redis = require('../models/redis')
|
||||
const apiKeyService = require('./apiKeyService')
|
||||
const CostCalculator = require('../utils/costCalculator')
|
||||
const logger = require('../utils/logger')
|
||||
|
||||
class CostInitService {
|
||||
/**
|
||||
@@ -10,173 +10,187 @@ class CostInitService {
|
||||
*/
|
||||
async initializeAllCosts() {
|
||||
try {
|
||||
logger.info('💰 Starting cost initialization for all API Keys...');
|
||||
|
||||
const apiKeys = await apiKeyService.getAllApiKeys();
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
let processedCount = 0;
|
||||
let errorCount = 0;
|
||||
|
||||
logger.info('💰 Starting cost initialization for all API Keys...')
|
||||
|
||||
const apiKeys = await apiKeyService.getAllApiKeys()
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
let processedCount = 0
|
||||
let errorCount = 0
|
||||
|
||||
for (const apiKey of apiKeys) {
|
||||
try {
|
||||
await this.initializeApiKeyCosts(apiKey.id, client);
|
||||
processedCount++;
|
||||
|
||||
await this.initializeApiKeyCosts(apiKey.id, client)
|
||||
processedCount++
|
||||
|
||||
if (processedCount % 10 === 0) {
|
||||
logger.info(`💰 Processed ${processedCount} API Keys...`);
|
||||
logger.info(`💰 Processed ${processedCount} API Keys...`)
|
||||
}
|
||||
} catch (error) {
|
||||
errorCount++;
|
||||
logger.error(`❌ Failed to initialize costs for API Key ${apiKey.id}:`, error);
|
||||
errorCount++
|
||||
logger.error(`❌ Failed to initialize costs for API Key ${apiKey.id}:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
logger.success(`💰 Cost initialization completed! Processed: ${processedCount}, Errors: ${errorCount}`);
|
||||
return { processed: processedCount, errors: errorCount };
|
||||
|
||||
logger.success(
|
||||
`💰 Cost initialization completed! Processed: ${processedCount}, Errors: ${errorCount}`
|
||||
)
|
||||
return { processed: processedCount, errors: errorCount }
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to initialize costs:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to initialize costs:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 初始化单个API Key的费用数据
|
||||
*/
|
||||
async initializeApiKeyCosts(apiKeyId, client) {
|
||||
// 获取所有时间的模型使用统计
|
||||
const modelKeys = await client.keys(`usage:${apiKeyId}:model:*:*:*`);
|
||||
|
||||
const modelKeys = await client.keys(`usage:${apiKeyId}:model:*:*:*`)
|
||||
|
||||
// 按日期分组统计
|
||||
const dailyCosts = new Map(); // date -> cost
|
||||
const monthlyCosts = new Map(); // month -> cost
|
||||
const hourlyCosts = new Map(); // hour -> cost
|
||||
|
||||
const dailyCosts = new Map() // date -> cost
|
||||
const monthlyCosts = new Map() // month -> cost
|
||||
const hourlyCosts = new Map() // hour -> cost
|
||||
|
||||
for (const key of modelKeys) {
|
||||
// 解析key格式: usage:{keyId}:model:{period}:{model}:{date}
|
||||
const match = key.match(/usage:(.+):model:(daily|monthly|hourly):(.+):(\d{4}-\d{2}(?:-\d{2})?(?::\d{2})?)$/);
|
||||
if (!match) continue;
|
||||
|
||||
const [, , period, model, dateStr] = match;
|
||||
|
||||
const match = key.match(
|
||||
/usage:(.+):model:(daily|monthly|hourly):(.+):(\d{4}-\d{2}(?:-\d{2})?(?::\d{2})?)$/
|
||||
)
|
||||
if (!match) {
|
||||
continue
|
||||
}
|
||||
|
||||
const [, , period, model, dateStr] = match
|
||||
|
||||
// 获取使用数据
|
||||
const data = await client.hgetall(key);
|
||||
if (!data || Object.keys(data).length === 0) continue;
|
||||
|
||||
const data = await client.hgetall(key)
|
||||
if (!data || Object.keys(data).length === 0) {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算费用
|
||||
const usage = {
|
||||
input_tokens: parseInt(data.totalInputTokens) || parseInt(data.inputTokens) || 0,
|
||||
output_tokens: parseInt(data.totalOutputTokens) || parseInt(data.outputTokens) || 0,
|
||||
cache_creation_input_tokens: parseInt(data.totalCacheCreateTokens) || parseInt(data.cacheCreateTokens) || 0,
|
||||
cache_read_input_tokens: parseInt(data.totalCacheReadTokens) || parseInt(data.cacheReadTokens) || 0
|
||||
};
|
||||
|
||||
const costResult = CostCalculator.calculateCost(usage, model);
|
||||
const cost = costResult.costs.total;
|
||||
|
||||
cache_creation_input_tokens:
|
||||
parseInt(data.totalCacheCreateTokens) || parseInt(data.cacheCreateTokens) || 0,
|
||||
cache_read_input_tokens:
|
||||
parseInt(data.totalCacheReadTokens) || parseInt(data.cacheReadTokens) || 0
|
||||
}
|
||||
|
||||
const costResult = CostCalculator.calculateCost(usage, model)
|
||||
const cost = costResult.costs.total
|
||||
|
||||
// 根据period分组累加费用
|
||||
if (period === 'daily') {
|
||||
const currentCost = dailyCosts.get(dateStr) || 0;
|
||||
dailyCosts.set(dateStr, currentCost + cost);
|
||||
const currentCost = dailyCosts.get(dateStr) || 0
|
||||
dailyCosts.set(dateStr, currentCost + cost)
|
||||
} else if (period === 'monthly') {
|
||||
const currentCost = monthlyCosts.get(dateStr) || 0;
|
||||
monthlyCosts.set(dateStr, currentCost + cost);
|
||||
const currentCost = monthlyCosts.get(dateStr) || 0
|
||||
monthlyCosts.set(dateStr, currentCost + cost)
|
||||
} else if (period === 'hourly') {
|
||||
const currentCost = hourlyCosts.get(dateStr) || 0;
|
||||
hourlyCosts.set(dateStr, currentCost + cost);
|
||||
const currentCost = hourlyCosts.get(dateStr) || 0
|
||||
hourlyCosts.set(dateStr, currentCost + cost)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 将计算出的费用写入Redis
|
||||
const promises = [];
|
||||
|
||||
const promises = []
|
||||
|
||||
// 写入每日费用
|
||||
for (const [date, cost] of dailyCosts) {
|
||||
const key = `usage:cost:daily:${apiKeyId}:${date}`;
|
||||
const key = `usage:cost:daily:${apiKeyId}:${date}`
|
||||
promises.push(
|
||||
client.set(key, cost.toString()),
|
||||
client.expire(key, 86400 * 30) // 30天过期
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// 写入每月费用
|
||||
for (const [month, cost] of monthlyCosts) {
|
||||
const key = `usage:cost:monthly:${apiKeyId}:${month}`;
|
||||
const key = `usage:cost:monthly:${apiKeyId}:${month}`
|
||||
promises.push(
|
||||
client.set(key, cost.toString()),
|
||||
client.expire(key, 86400 * 90) // 90天过期
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// 写入每小时费用
|
||||
for (const [hour, cost] of hourlyCosts) {
|
||||
const key = `usage:cost:hourly:${apiKeyId}:${hour}`;
|
||||
const key = `usage:cost:hourly:${apiKeyId}:${hour}`
|
||||
promises.push(
|
||||
client.set(key, cost.toString()),
|
||||
client.expire(key, 86400 * 7) // 7天过期
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
// 计算总费用
|
||||
let totalCost = 0;
|
||||
let totalCost = 0
|
||||
for (const cost of dailyCosts.values()) {
|
||||
totalCost += cost;
|
||||
totalCost += cost
|
||||
}
|
||||
|
||||
|
||||
// 写入总费用
|
||||
if (totalCost > 0) {
|
||||
const totalKey = `usage:cost:total:${apiKeyId}`;
|
||||
promises.push(client.set(totalKey, totalCost.toString()));
|
||||
const totalKey = `usage:cost:total:${apiKeyId}`
|
||||
promises.push(client.set(totalKey, totalCost.toString()))
|
||||
}
|
||||
|
||||
await Promise.all(promises);
|
||||
|
||||
logger.debug(`💰 Initialized costs for API Key ${apiKeyId}: Daily entries: ${dailyCosts.size}, Total cost: $${totalCost.toFixed(2)}`);
|
||||
|
||||
await Promise.all(promises)
|
||||
|
||||
logger.debug(
|
||||
`💰 Initialized costs for API Key ${apiKeyId}: Daily entries: ${dailyCosts.size}, Total cost: $${totalCost.toFixed(2)}`
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 检查是否需要初始化费用数据
|
||||
*/
|
||||
async needsInitialization() {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
|
||||
// 检查是否有任何费用数据
|
||||
const costKeys = await client.keys('usage:cost:*');
|
||||
|
||||
const costKeys = await client.keys('usage:cost:*')
|
||||
|
||||
// 如果没有费用数据,需要初始化
|
||||
if (costKeys.length === 0) {
|
||||
logger.info('💰 No cost data found, initialization needed');
|
||||
return true;
|
||||
logger.info('💰 No cost data found, initialization needed')
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
// 检查是否有使用数据但没有对应的费用数据
|
||||
const sampleKeys = await client.keys('usage:*:model:daily:*:*');
|
||||
const sampleKeys = await client.keys('usage:*:model:daily:*:*')
|
||||
if (sampleKeys.length > 10) {
|
||||
// 抽样检查
|
||||
const sampleSize = Math.min(10, sampleKeys.length);
|
||||
const sampleSize = Math.min(10, sampleKeys.length)
|
||||
for (let i = 0; i < sampleSize; i++) {
|
||||
const usageKey = sampleKeys[Math.floor(Math.random() * sampleKeys.length)];
|
||||
const match = usageKey.match(/usage:(.+):model:daily:(.+):(\d{4}-\d{2}-\d{2})$/);
|
||||
const usageKey = sampleKeys[Math.floor(Math.random() * sampleKeys.length)]
|
||||
const match = usageKey.match(/usage:(.+):model:daily:(.+):(\d{4}-\d{2}-\d{2})$/)
|
||||
if (match) {
|
||||
const [, keyId, , date] = match;
|
||||
const costKey = `usage:cost:daily:${keyId}:${date}`;
|
||||
const hasCost = await client.exists(costKey);
|
||||
const [, keyId, , date] = match
|
||||
const costKey = `usage:cost:daily:${keyId}:${date}`
|
||||
const hasCost = await client.exists(costKey)
|
||||
if (!hasCost) {
|
||||
logger.info(`💰 Found usage without cost data for key ${keyId} on ${date}, initialization needed`);
|
||||
return true;
|
||||
logger.info(
|
||||
`💰 Found usage without cost data for key ${keyId} on ${date}, initialization needed`
|
||||
)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info('💰 Cost data appears to be up to date');
|
||||
return false;
|
||||
|
||||
logger.info('💰 Cost data appears to be up to date')
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to check initialization status:', error);
|
||||
return false;
|
||||
logger.error('❌ Failed to check initialization status:', error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new CostInitService();
|
||||
module.exports = new CostInitService()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,228 +1,243 @@
|
||||
const axios = require('axios');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const { SocksProxyAgent } = require('socks-proxy-agent');
|
||||
const logger = require('../utils/logger');
|
||||
const config = require('../../config/config');
|
||||
const apiKeyService = require('./apiKeyService');
|
||||
const axios = require('axios')
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent')
|
||||
const { SocksProxyAgent } = require('socks-proxy-agent')
|
||||
const logger = require('../utils/logger')
|
||||
const config = require('../../config/config')
|
||||
const apiKeyService = require('./apiKeyService')
|
||||
|
||||
// Gemini API 配置
|
||||
const GEMINI_API_BASE = 'https://cloudcode.googleapis.com/v1';
|
||||
const DEFAULT_MODEL = 'models/gemini-2.0-flash-exp';
|
||||
const GEMINI_API_BASE = 'https://cloudcode.googleapis.com/v1'
|
||||
const DEFAULT_MODEL = 'models/gemini-2.0-flash-exp'
|
||||
|
||||
// 创建代理 agent
|
||||
function createProxyAgent(proxyConfig) {
|
||||
if (!proxyConfig) return null;
|
||||
|
||||
if (!proxyConfig) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const proxy = typeof proxyConfig === 'string' ? JSON.parse(proxyConfig) : proxyConfig;
|
||||
|
||||
const proxy = typeof proxyConfig === 'string' ? JSON.parse(proxyConfig) : proxyConfig
|
||||
|
||||
if (!proxy.type || !proxy.host || !proxy.port) {
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
const proxyUrl = proxy.username && proxy.password
|
||||
? `${proxy.type}://${proxy.username}:${proxy.password}@${proxy.host}:${proxy.port}`
|
||||
: `${proxy.type}://${proxy.host}:${proxy.port}`;
|
||||
|
||||
|
||||
const proxyUrl =
|
||||
proxy.username && proxy.password
|
||||
? `${proxy.type}://${proxy.username}:${proxy.password}@${proxy.host}:${proxy.port}`
|
||||
: `${proxy.type}://${proxy.host}:${proxy.port}`
|
||||
|
||||
if (proxy.type === 'socks5') {
|
||||
return new SocksProxyAgent(proxyUrl);
|
||||
return new SocksProxyAgent(proxyUrl)
|
||||
} else if (proxy.type === 'http' || proxy.type === 'https') {
|
||||
return new HttpsProxyAgent(proxyUrl);
|
||||
return new HttpsProxyAgent(proxyUrl)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error creating proxy agent:', error);
|
||||
logger.error('Error creating proxy agent:', error)
|
||||
}
|
||||
|
||||
return null;
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
// 转换 OpenAI 消息格式到 Gemini 格式
|
||||
function convertMessagesToGemini(messages) {
|
||||
const contents = [];
|
||||
let systemInstruction = '';
|
||||
|
||||
const contents = []
|
||||
let systemInstruction = ''
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.role === 'system') {
|
||||
systemInstruction += (systemInstruction ? '\n\n' : '') + message.content;
|
||||
systemInstruction += (systemInstruction ? '\n\n' : '') + message.content
|
||||
} else if (message.role === 'user') {
|
||||
contents.push({
|
||||
role: 'user',
|
||||
parts: [{ text: message.content }]
|
||||
});
|
||||
})
|
||||
} else if (message.role === 'assistant') {
|
||||
contents.push({
|
||||
role: 'model',
|
||||
parts: [{ text: message.content }]
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return { contents, systemInstruction };
|
||||
|
||||
return { contents, systemInstruction }
|
||||
}
|
||||
|
||||
// 转换 Gemini 响应到 OpenAI 格式
|
||||
function convertGeminiResponse(geminiResponse, model, stream = false) {
|
||||
if (stream) {
|
||||
// 流式响应
|
||||
const candidate = geminiResponse.candidates?.[0];
|
||||
if (!candidate) return null;
|
||||
|
||||
const content = candidate.content?.parts?.[0]?.text || '';
|
||||
const finishReason = candidate.finishReason?.toLowerCase();
|
||||
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
if (!candidate) {
|
||||
return null
|
||||
}
|
||||
|
||||
const content = candidate.content?.parts?.[0]?.text || ''
|
||||
const finishReason = candidate.finishReason?.toLowerCase()
|
||||
|
||||
return {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion.chunk',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: {
|
||||
content: content
|
||||
},
|
||||
finish_reason: finishReason === 'stop' ? 'stop' : null
|
||||
}]
|
||||
};
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {
|
||||
content
|
||||
},
|
||||
finish_reason: finishReason === 'stop' ? 'stop' : null
|
||||
}
|
||||
]
|
||||
}
|
||||
} else {
|
||||
// 非流式响应
|
||||
const candidate = geminiResponse.candidates?.[0];
|
||||
const candidate = geminiResponse.candidates?.[0]
|
||||
if (!candidate) {
|
||||
throw new Error('No response from Gemini');
|
||||
throw new Error('No response from Gemini')
|
||||
}
|
||||
|
||||
const content = candidate.content?.parts?.[0]?.text || '';
|
||||
const finishReason = candidate.finishReason?.toLowerCase() || 'stop';
|
||||
|
||||
|
||||
const content = candidate.content?.parts?.[0]?.text || ''
|
||||
const finishReason = candidate.finishReason?.toLowerCase() || 'stop'
|
||||
|
||||
// 计算 token 使用量
|
||||
const usage = geminiResponse.usageMetadata || {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
return {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: 'chat.completion',
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model,
|
||||
choices: [{
|
||||
index: 0,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: content
|
||||
},
|
||||
finish_reason: finishReason
|
||||
}],
|
||||
model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content
|
||||
},
|
||||
finish_reason: finishReason
|
||||
}
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: usage.promptTokenCount,
|
||||
completion_tokens: usage.candidatesTokenCount,
|
||||
total_tokens: usage.totalTokenCount
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理流式响应
|
||||
async function* handleStreamResponse(response, model, apiKeyId, accountId = null) {
|
||||
let buffer = '';
|
||||
let buffer = ''
|
||||
let totalUsage = {
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
try {
|
||||
for await (const chunk of response.data) {
|
||||
buffer += chunk.toString();
|
||||
|
||||
buffer += chunk.toString()
|
||||
|
||||
// 处理 SSE 格式的数据
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || ''; // 保留最后一个不完整的行
|
||||
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || '' // 保留最后一个不完整的行
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.trim()) continue;
|
||||
|
||||
// 处理 SSE 格式: "data: {...}"
|
||||
let jsonData = line;
|
||||
if (line.startsWith('data: ')) {
|
||||
jsonData = line.substring(6).trim();
|
||||
if (!line.trim()) {
|
||||
continue
|
||||
}
|
||||
|
||||
if (!jsonData || jsonData === '[DONE]') continue;
|
||||
// 处理 SSE 格式: "data: {...}"
|
||||
let jsonData = line
|
||||
if (line.startsWith('data: ')) {
|
||||
jsonData = line.substring(6).trim()
|
||||
}
|
||||
|
||||
if (!jsonData || jsonData === '[DONE]') {
|
||||
continue
|
||||
}
|
||||
|
||||
try {
|
||||
const data = JSON.parse(jsonData);
|
||||
|
||||
const data = JSON.parse(jsonData)
|
||||
|
||||
// 更新使用量统计
|
||||
if (data.usageMetadata) {
|
||||
totalUsage = data.usageMetadata;
|
||||
totalUsage = data.usageMetadata
|
||||
}
|
||||
|
||||
|
||||
// 转换并发送响应
|
||||
const openaiResponse = convertGeminiResponse(data, model, true);
|
||||
const openaiResponse = convertGeminiResponse(data, model, true)
|
||||
if (openaiResponse) {
|
||||
yield `data: ${JSON.stringify(openaiResponse)}\n\n`;
|
||||
yield `data: ${JSON.stringify(openaiResponse)}\n\n`
|
||||
}
|
||||
|
||||
|
||||
// 检查是否结束
|
||||
if (data.candidates?.[0]?.finishReason === 'STOP') {
|
||||
// 记录使用量
|
||||
if (apiKeyId && totalUsage.totalTokenCount > 0) {
|
||||
await apiKeyService.recordUsage(
|
||||
apiKeyId,
|
||||
totalUsage.promptTokenCount || 0, // inputTokens
|
||||
totalUsage.candidatesTokenCount || 0, // outputTokens
|
||||
0, // cacheCreateTokens (Gemini 没有这个概念)
|
||||
0, // cacheReadTokens (Gemini 没有这个概念)
|
||||
model,
|
||||
accountId
|
||||
).catch(error => {
|
||||
logger.error('❌ Failed to record Gemini usage:', error);
|
||||
});
|
||||
await apiKeyService
|
||||
.recordUsage(
|
||||
apiKeyId,
|
||||
totalUsage.promptTokenCount || 0, // inputTokens
|
||||
totalUsage.candidatesTokenCount || 0, // outputTokens
|
||||
0, // cacheCreateTokens (Gemini 没有这个概念)
|
||||
0, // cacheReadTokens (Gemini 没有这个概念)
|
||||
model,
|
||||
accountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record Gemini usage:', error)
|
||||
})
|
||||
}
|
||||
|
||||
yield 'data: [DONE]\n\n';
|
||||
return;
|
||||
|
||||
yield 'data: [DONE]\n\n'
|
||||
return
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug('Error parsing JSON line:', e.message, 'Line:', jsonData);
|
||||
logger.debug('Error parsing JSON line:', e.message, 'Line:', jsonData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 处理剩余的 buffer
|
||||
if (buffer.trim()) {
|
||||
try {
|
||||
let jsonData = buffer.trim();
|
||||
let jsonData = buffer.trim()
|
||||
if (jsonData.startsWith('data: ')) {
|
||||
jsonData = jsonData.substring(6).trim();
|
||||
jsonData = jsonData.substring(6).trim()
|
||||
}
|
||||
|
||||
if (jsonData && jsonData !== '[DONE]') {
|
||||
const data = JSON.parse(jsonData);
|
||||
const openaiResponse = convertGeminiResponse(data, model, true);
|
||||
const data = JSON.parse(jsonData)
|
||||
const openaiResponse = convertGeminiResponse(data, model, true)
|
||||
if (openaiResponse) {
|
||||
yield `data: ${JSON.stringify(openaiResponse)}\n\n`;
|
||||
yield `data: ${JSON.stringify(openaiResponse)}\n\n`
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.debug('Error parsing final buffer:', e.message);
|
||||
logger.debug('Error parsing final buffer:', e.message)
|
||||
}
|
||||
}
|
||||
|
||||
yield 'data: [DONE]\n\n';
|
||||
|
||||
yield 'data: [DONE]\n\n'
|
||||
} catch (error) {
|
||||
// 检查是否是请求被中止
|
||||
if (error.name === 'CanceledError' || error.code === 'ECONNABORTED') {
|
||||
logger.info('Stream request was aborted by client');
|
||||
logger.info('Stream request was aborted by client')
|
||||
} else {
|
||||
logger.error('Stream processing error:', error);
|
||||
logger.error('Stream processing error:', error)
|
||||
yield `data: ${JSON.stringify({
|
||||
error: {
|
||||
message: error.message,
|
||||
type: 'stream_error'
|
||||
}
|
||||
})}\n\n`;
|
||||
})}\n\n`
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -244,12 +259,12 @@ async function sendGeminiRequest({
|
||||
}) {
|
||||
// 确保模型名称格式正确
|
||||
if (!model.startsWith('models/')) {
|
||||
model = `models/${model}`;
|
||||
model = `models/${model}`
|
||||
}
|
||||
|
||||
|
||||
// 转换消息格式
|
||||
const { contents, systemInstruction } = convertMessagesToGemini(messages);
|
||||
|
||||
const { contents, systemInstruction } = convertMessagesToGemini(messages)
|
||||
|
||||
// 构建请求体
|
||||
const requestBody = {
|
||||
contents,
|
||||
@@ -258,160 +273,162 @@ async function sendGeminiRequest({
|
||||
maxOutputTokens: maxTokens,
|
||||
candidateCount: 1
|
||||
}
|
||||
};
|
||||
|
||||
if (systemInstruction) {
|
||||
requestBody.systemInstruction = { parts: [{ text: systemInstruction }] };
|
||||
}
|
||||
|
||||
|
||||
if (systemInstruction) {
|
||||
requestBody.systemInstruction = { parts: [{ text: systemInstruction }] }
|
||||
}
|
||||
|
||||
// 配置请求选项
|
||||
let apiUrl;
|
||||
let apiUrl
|
||||
if (projectId) {
|
||||
// 使用项目特定的 URL 格式(Google Cloud/Workspace 账号)
|
||||
apiUrl = `${GEMINI_API_BASE}/projects/${projectId}/locations/${location}/${model}:${stream ? 'streamGenerateContent' : 'generateContent'}?alt=sse`;
|
||||
logger.debug(`Using project-specific URL with projectId: ${projectId}, location: ${location}`);
|
||||
apiUrl = `${GEMINI_API_BASE}/projects/${projectId}/locations/${location}/${model}:${stream ? 'streamGenerateContent' : 'generateContent'}?alt=sse`
|
||||
logger.debug(`Using project-specific URL with projectId: ${projectId}, location: ${location}`)
|
||||
} else {
|
||||
// 使用标准 URL 格式(个人 Google 账号)
|
||||
apiUrl = `${GEMINI_API_BASE}/${model}:${stream ? 'streamGenerateContent' : 'generateContent'}?alt=sse`;
|
||||
logger.debug('Using standard URL without projectId');
|
||||
apiUrl = `${GEMINI_API_BASE}/${model}:${stream ? 'streamGenerateContent' : 'generateContent'}?alt=sse`
|
||||
logger.debug('Using standard URL without projectId')
|
||||
}
|
||||
|
||||
|
||||
const axiosConfig = {
|
||||
method: 'POST',
|
||||
url: apiUrl,
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`,
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
data: requestBody,
|
||||
timeout: config.requestTimeout || 120000
|
||||
};
|
||||
|
||||
// 添加代理配置
|
||||
const proxyAgent = createProxyAgent(proxy);
|
||||
if (proxyAgent) {
|
||||
axiosConfig.httpsAgent = proxyAgent;
|
||||
logger.debug('Using proxy for Gemini request');
|
||||
}
|
||||
|
||||
|
||||
// 添加代理配置
|
||||
const proxyAgent = createProxyAgent(proxy)
|
||||
if (proxyAgent) {
|
||||
axiosConfig.httpsAgent = proxyAgent
|
||||
logger.debug('Using proxy for Gemini request')
|
||||
}
|
||||
|
||||
// 添加 AbortController 信号支持
|
||||
if (signal) {
|
||||
axiosConfig.signal = signal;
|
||||
logger.debug('AbortController signal attached to request');
|
||||
axiosConfig.signal = signal
|
||||
logger.debug('AbortController signal attached to request')
|
||||
}
|
||||
|
||||
|
||||
if (stream) {
|
||||
axiosConfig.responseType = 'stream';
|
||||
axiosConfig.responseType = 'stream'
|
||||
}
|
||||
|
||||
|
||||
try {
|
||||
logger.debug('Sending request to Gemini API');
|
||||
const response = await axios(axiosConfig);
|
||||
|
||||
logger.debug('Sending request to Gemini API')
|
||||
const response = await axios(axiosConfig)
|
||||
|
||||
if (stream) {
|
||||
return handleStreamResponse(response, model, apiKeyId, accountId);
|
||||
return handleStreamResponse(response, model, apiKeyId, accountId)
|
||||
} else {
|
||||
// 非流式响应
|
||||
const openaiResponse = convertGeminiResponse(response.data, model, false);
|
||||
|
||||
const openaiResponse = convertGeminiResponse(response.data, model, false)
|
||||
|
||||
// 记录使用量
|
||||
if (apiKeyId && openaiResponse.usage) {
|
||||
await apiKeyService.recordUsage(
|
||||
apiKeyId,
|
||||
openaiResponse.usage.prompt_tokens || 0,
|
||||
openaiResponse.usage.completion_tokens || 0,
|
||||
0, // cacheCreateTokens
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
accountId
|
||||
).catch(error => {
|
||||
logger.error('❌ Failed to record Gemini usage:', error);
|
||||
});
|
||||
await apiKeyService
|
||||
.recordUsage(
|
||||
apiKeyId,
|
||||
openaiResponse.usage.prompt_tokens || 0,
|
||||
openaiResponse.usage.completion_tokens || 0,
|
||||
0, // cacheCreateTokens
|
||||
0, // cacheReadTokens
|
||||
model,
|
||||
accountId
|
||||
)
|
||||
.catch((error) => {
|
||||
logger.error('❌ Failed to record Gemini usage:', error)
|
||||
})
|
||||
}
|
||||
|
||||
return openaiResponse;
|
||||
|
||||
return openaiResponse
|
||||
}
|
||||
} catch (error) {
|
||||
// 检查是否是请求被中止
|
||||
if (error.name === 'CanceledError' || error.code === 'ECONNABORTED') {
|
||||
logger.info('Gemini request was aborted by client');
|
||||
throw {
|
||||
status: 499,
|
||||
error: {
|
||||
message: 'Request canceled by client',
|
||||
type: 'canceled',
|
||||
code: 'request_canceled'
|
||||
}
|
||||
};
|
||||
logger.info('Gemini request was aborted by client')
|
||||
const err = new Error('Request canceled by client')
|
||||
err.status = 499
|
||||
err.error = {
|
||||
message: 'Request canceled by client',
|
||||
type: 'canceled',
|
||||
code: 'request_canceled'
|
||||
}
|
||||
throw err
|
||||
}
|
||||
|
||||
logger.error('Gemini API request failed:', error.response?.data || error.message);
|
||||
|
||||
|
||||
logger.error('Gemini API request failed:', error.response?.data || error.message)
|
||||
|
||||
// 转换错误格式
|
||||
if (error.response) {
|
||||
const geminiError = error.response.data?.error;
|
||||
throw {
|
||||
status: error.response.status,
|
||||
error: {
|
||||
message: geminiError?.message || 'Gemini API request failed',
|
||||
type: geminiError?.code || 'api_error',
|
||||
code: geminiError?.code
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
throw {
|
||||
status: 500,
|
||||
error: {
|
||||
message: error.message,
|
||||
type: 'network_error'
|
||||
const geminiError = error.response.data?.error
|
||||
const err = new Error(geminiError?.message || 'Gemini API request failed')
|
||||
err.status = error.response.status
|
||||
err.error = {
|
||||
message: geminiError?.message || 'Gemini API request failed',
|
||||
type: geminiError?.code || 'api_error',
|
||||
code: geminiError?.code
|
||||
}
|
||||
};
|
||||
throw err
|
||||
}
|
||||
|
||||
const err = new Error(error.message)
|
||||
err.status = 500
|
||||
err.error = {
|
||||
message: error.message,
|
||||
type: 'network_error'
|
||||
}
|
||||
throw err
|
||||
}
|
||||
}
|
||||
|
||||
// 获取可用模型列表
|
||||
async function getAvailableModels(accessToken, proxy, projectId, location = 'us-central1') {
|
||||
let apiUrl;
|
||||
let apiUrl
|
||||
if (projectId) {
|
||||
// 使用项目特定的 URL 格式
|
||||
apiUrl = `${GEMINI_API_BASE}/projects/${projectId}/locations/${location}/models`;
|
||||
logger.debug(`Fetching models with projectId: ${projectId}, location: ${location}`);
|
||||
apiUrl = `${GEMINI_API_BASE}/projects/${projectId}/locations/${location}/models`
|
||||
logger.debug(`Fetching models with projectId: ${projectId}, location: ${location}`)
|
||||
} else {
|
||||
// 使用标准 URL 格式
|
||||
apiUrl = `${GEMINI_API_BASE}/models`;
|
||||
logger.debug('Fetching models without projectId');
|
||||
apiUrl = `${GEMINI_API_BASE}/models`
|
||||
logger.debug('Fetching models without projectId')
|
||||
}
|
||||
|
||||
|
||||
const axiosConfig = {
|
||||
method: 'GET',
|
||||
url: apiUrl,
|
||||
headers: {
|
||||
'Authorization': `Bearer ${accessToken}`
|
||||
Authorization: `Bearer ${accessToken}`
|
||||
},
|
||||
timeout: 30000
|
||||
};
|
||||
|
||||
const proxyAgent = createProxyAgent(proxy);
|
||||
if (proxyAgent) {
|
||||
axiosConfig.httpsAgent = proxyAgent;
|
||||
}
|
||||
|
||||
|
||||
const proxyAgent = createProxyAgent(proxy)
|
||||
if (proxyAgent) {
|
||||
axiosConfig.httpsAgent = proxyAgent
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios(axiosConfig);
|
||||
const models = response.data.models || [];
|
||||
|
||||
const response = await axios(axiosConfig)
|
||||
const models = response.data.models || []
|
||||
|
||||
// 转换为 OpenAI 格式
|
||||
return models
|
||||
.filter(model => model.supportedGenerationMethods?.includes('generateContent'))
|
||||
.map(model => ({
|
||||
.filter((model) => model.supportedGenerationMethods?.includes('generateContent'))
|
||||
.map((model) => ({
|
||||
id: model.name.replace('models/', ''),
|
||||
object: 'model',
|
||||
created: Date.now() / 1000,
|
||||
owned_by: 'google'
|
||||
}));
|
||||
}))
|
||||
} catch (error) {
|
||||
logger.error('Failed to get Gemini models:', error);
|
||||
logger.error('Failed to get Gemini models:', error)
|
||||
// 返回默认模型列表
|
||||
return [
|
||||
{
|
||||
@@ -420,7 +437,7 @@ async function getAvailableModels(accessToken, proxy, projectId, location = 'us-
|
||||
created: Date.now() / 1000,
|
||||
owned_by: 'google'
|
||||
}
|
||||
];
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -429,4 +446,4 @@ module.exports = {
|
||||
getAvailableModels,
|
||||
convertMessagesToGemini,
|
||||
convertGeminiResponse
|
||||
};
|
||||
}
|
||||
|
||||
@@ -3,17 +3,17 @@
|
||||
* 处理 OpenAI API 格式与 Claude API 格式之间的转换
|
||||
*/
|
||||
|
||||
const logger = require('../utils/logger');
|
||||
const logger = require('../utils/logger')
|
||||
|
||||
class OpenAIToClaudeConverter {
|
||||
constructor() {
|
||||
// 停止原因映射
|
||||
this.stopReasonMapping = {
|
||||
'end_turn': 'stop',
|
||||
'max_tokens': 'length',
|
||||
'stop_sequence': 'stop',
|
||||
'tool_use': 'tool_calls'
|
||||
};
|
||||
end_turn: 'stop',
|
||||
max_tokens: 'length',
|
||||
stop_sequence: 'stop',
|
||||
tool_use: 'tool_calls'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -29,39 +29,39 @@ class OpenAIToClaudeConverter {
|
||||
temperature: openaiRequest.temperature,
|
||||
top_p: openaiRequest.top_p,
|
||||
stream: openaiRequest.stream || false
|
||||
};
|
||||
}
|
||||
|
||||
// Claude Code 必需的系统消息
|
||||
const claudeCodeSystemMessage = 'You are Claude Code, Anthropic\'s official CLI for Claude.';
|
||||
|
||||
claudeRequest.system = claudeCodeSystemMessage;
|
||||
const claudeCodeSystemMessage = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
|
||||
claudeRequest.system = claudeCodeSystemMessage
|
||||
|
||||
// 处理停止序列
|
||||
if (openaiRequest.stop) {
|
||||
claudeRequest.stop_sequences = Array.isArray(openaiRequest.stop)
|
||||
? openaiRequest.stop
|
||||
: [openaiRequest.stop];
|
||||
claudeRequest.stop_sequences = Array.isArray(openaiRequest.stop)
|
||||
? openaiRequest.stop
|
||||
: [openaiRequest.stop]
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if (openaiRequest.tools) {
|
||||
claudeRequest.tools = this._convertTools(openaiRequest.tools);
|
||||
claudeRequest.tools = this._convertTools(openaiRequest.tools)
|
||||
if (openaiRequest.tool_choice) {
|
||||
claudeRequest.tool_choice = this._convertToolChoice(openaiRequest.tool_choice);
|
||||
claudeRequest.tool_choice = this._convertToolChoice(openaiRequest.tool_choice)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI 特有的参数已在转换过程中被忽略
|
||||
// 包括: n, presence_penalty, frequency_penalty, logit_bias, user
|
||||
|
||||
|
||||
logger.debug('📝 Converted OpenAI request to Claude format:', {
|
||||
model: claudeRequest.model,
|
||||
messageCount: claudeRequest.messages.length,
|
||||
hasSystem: !!claudeRequest.system,
|
||||
stream: claudeRequest.stream
|
||||
});
|
||||
})
|
||||
|
||||
return claudeRequest;
|
||||
return claudeRequest
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -71,28 +71,30 @@ class OpenAIToClaudeConverter {
|
||||
* @returns {Object} OpenAI 格式的响应
|
||||
*/
|
||||
convertResponse(claudeResponse, requestModel) {
|
||||
const timestamp = Math.floor(Date.now() / 1000);
|
||||
|
||||
const timestamp = Math.floor(Date.now() / 1000)
|
||||
|
||||
const openaiResponse = {
|
||||
id: `chatcmpl-${this._generateId()}`,
|
||||
object: 'chat.completion',
|
||||
created: timestamp,
|
||||
model: requestModel || 'gpt-4',
|
||||
choices: [{
|
||||
index: 0,
|
||||
message: this._convertClaudeMessage(claudeResponse),
|
||||
finish_reason: this._mapStopReason(claudeResponse.stop_reason)
|
||||
}],
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: this._convertClaudeMessage(claudeResponse),
|
||||
finish_reason: this._mapStopReason(claudeResponse.stop_reason)
|
||||
}
|
||||
],
|
||||
usage: this._convertUsage(claudeResponse.usage)
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug('📝 Converted Claude response to OpenAI format:', {
|
||||
responseId: openaiResponse.id,
|
||||
finishReason: openaiResponse.choices[0].finish_reason,
|
||||
usage: openaiResponse.usage
|
||||
});
|
||||
})
|
||||
|
||||
return openaiResponse;
|
||||
return openaiResponse
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -103,36 +105,38 @@ class OpenAIToClaudeConverter {
|
||||
* @returns {String} OpenAI 格式的 SSE 数据块
|
||||
*/
|
||||
convertStreamChunk(chunk, requestModel, sessionId) {
|
||||
if (!chunk || chunk.trim() === '') return '';
|
||||
|
||||
if (!chunk || chunk.trim() === '') {
|
||||
return ''
|
||||
}
|
||||
|
||||
// 解析 SSE 数据
|
||||
const lines = chunk.split('\n');
|
||||
let convertedChunks = [];
|
||||
let hasMessageStop = false;
|
||||
const lines = chunk.split('\n')
|
||||
const convertedChunks = []
|
||||
let hasMessageStop = false
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.substring(6);
|
||||
const data = line.substring(6)
|
||||
if (data === '[DONE]') {
|
||||
convertedChunks.push('data: [DONE]\n\n');
|
||||
continue;
|
||||
convertedChunks.push('data: [DONE]\n\n')
|
||||
continue
|
||||
}
|
||||
|
||||
try {
|
||||
const claudeEvent = JSON.parse(data);
|
||||
|
||||
const claudeEvent = JSON.parse(data)
|
||||
|
||||
// 检查是否是 message_stop 事件
|
||||
if (claudeEvent.type === 'message_stop') {
|
||||
hasMessageStop = true;
|
||||
hasMessageStop = true
|
||||
}
|
||||
|
||||
const openaiChunk = this._convertStreamEvent(claudeEvent, requestModel, sessionId);
|
||||
|
||||
const openaiChunk = this._convertStreamEvent(claudeEvent, requestModel, sessionId)
|
||||
if (openaiChunk) {
|
||||
convertedChunks.push(`data: ${JSON.stringify(openaiChunk)}\n\n`);
|
||||
convertedChunks.push(`data: ${JSON.stringify(openaiChunk)}\n\n`)
|
||||
}
|
||||
} catch (e) {
|
||||
// 跳过无法解析的数据,不传递非JSON格式的行
|
||||
continue;
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 忽略 event: 行和空行,OpenAI 格式不包含这些
|
||||
@@ -140,95 +144,102 @@ class OpenAIToClaudeConverter {
|
||||
|
||||
// 如果收到 message_stop 事件,添加 [DONE] 标记
|
||||
if (hasMessageStop) {
|
||||
convertedChunks.push('data: [DONE]\n\n');
|
||||
convertedChunks.push('data: [DONE]\n\n')
|
||||
}
|
||||
|
||||
return convertedChunks.join('');
|
||||
return convertedChunks.join('')
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 提取系统消息
|
||||
*/
|
||||
_extractSystemMessage(messages) {
|
||||
const systemMessages = messages.filter(msg => msg.role === 'system');
|
||||
if (systemMessages.length === 0) return null;
|
||||
|
||||
const systemMessages = messages.filter((msg) => msg.role === 'system')
|
||||
if (systemMessages.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
// 合并所有系统消息
|
||||
return systemMessages.map(msg => msg.content).join('\n\n');
|
||||
return systemMessages.map((msg) => msg.content).join('\n\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换消息格式
|
||||
*/
|
||||
_convertMessages(messages) {
|
||||
const claudeMessages = [];
|
||||
|
||||
const claudeMessages = []
|
||||
|
||||
for (const msg of messages) {
|
||||
// 跳过系统消息(已经在 system 字段处理)
|
||||
if (msg.role === 'system') continue;
|
||||
|
||||
// 转换角色名称
|
||||
const role = msg.role === 'user' ? 'user' : 'assistant';
|
||||
|
||||
// 转换消息内容
|
||||
let content;
|
||||
if (typeof msg.content === 'string') {
|
||||
content = msg.content;
|
||||
} else if (Array.isArray(msg.content)) {
|
||||
// 处理多模态内容
|
||||
content = this._convertMultimodalContent(msg.content);
|
||||
} else {
|
||||
content = JSON.stringify(msg.content);
|
||||
if (msg.role === 'system') {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// 转换角色名称
|
||||
const role = msg.role === 'user' ? 'user' : 'assistant'
|
||||
|
||||
// 转换消息内容
|
||||
const { content: rawContent } = msg
|
||||
let content
|
||||
|
||||
if (typeof rawContent === 'string') {
|
||||
content = rawContent
|
||||
} else if (Array.isArray(rawContent)) {
|
||||
// 处理多模态内容
|
||||
content = this._convertMultimodalContent(rawContent)
|
||||
} else {
|
||||
content = JSON.stringify(rawContent)
|
||||
}
|
||||
|
||||
const claudeMsg = {
|
||||
role: role,
|
||||
content: content
|
||||
};
|
||||
|
||||
role,
|
||||
content
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if (msg.tool_calls) {
|
||||
claudeMsg.content = this._convertToolCalls(msg.tool_calls);
|
||||
claudeMsg.content = this._convertToolCalls(msg.tool_calls)
|
||||
}
|
||||
|
||||
|
||||
// 处理工具响应
|
||||
if (msg.role === 'tool') {
|
||||
claudeMsg.role = 'user';
|
||||
claudeMsg.content = [{
|
||||
type: 'tool_result',
|
||||
tool_use_id: msg.tool_call_id,
|
||||
content: msg.content
|
||||
}];
|
||||
claudeMsg.role = 'user'
|
||||
claudeMsg.content = [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: msg.tool_call_id,
|
||||
content: msg.content
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
claudeMessages.push(claudeMsg);
|
||||
|
||||
claudeMessages.push(claudeMsg)
|
||||
}
|
||||
|
||||
return claudeMessages;
|
||||
|
||||
return claudeMessages
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换多模态内容
|
||||
*/
|
||||
_convertMultimodalContent(content) {
|
||||
return content.map(item => {
|
||||
return content.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
return {
|
||||
type: 'text',
|
||||
text: item.text
|
||||
};
|
||||
}
|
||||
} else if (item.type === 'image_url') {
|
||||
const imageUrl = item.image_url.url;
|
||||
|
||||
const imageUrl = item.image_url.url
|
||||
|
||||
// 检查是否是 base64 格式的图片
|
||||
if (imageUrl.startsWith('data:')) {
|
||||
// 解析 data URL: ...
|
||||
const matches = imageUrl.match(/^data:([^;]+);base64,(.+)$/);
|
||||
const matches = imageUrl.match(/^data:([^;]+);base64,(.+)$/)
|
||||
if (matches) {
|
||||
const mediaType = matches[1]; // e.g., 'image/jpeg', 'image/png'
|
||||
const base64Data = matches[2];
|
||||
|
||||
const mediaType = matches[1] // e.g., 'image/jpeg', 'image/png'
|
||||
const base64Data = matches[2]
|
||||
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
@@ -236,10 +247,10 @@ class OpenAIToClaudeConverter {
|
||||
media_type: mediaType,
|
||||
data: base64Data
|
||||
}
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// 如果格式不正确,尝试使用默认处理
|
||||
logger.warn('⚠️ Invalid base64 image format, using default parsing');
|
||||
logger.warn('⚠️ Invalid base64 image format, using default parsing')
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
@@ -247,60 +258,70 @@ class OpenAIToClaudeConverter {
|
||||
media_type: 'image/jpeg',
|
||||
data: imageUrl.split(',')[1] || ''
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果是 URL 格式的图片,Claude 不支持直接 URL,需要报错
|
||||
logger.error('❌ URL images are not supported by Claude API, only base64 format is accepted');
|
||||
throw new Error('Claude API only supports base64 encoded images, not URLs. Please convert the image to base64 format.');
|
||||
logger.error(
|
||||
'❌ URL images are not supported by Claude API, only base64 format is accepted'
|
||||
)
|
||||
throw new Error(
|
||||
'Claude API only supports base64 encoded images, not URLs. Please convert the image to base64 format.'
|
||||
)
|
||||
}
|
||||
}
|
||||
return item;
|
||||
});
|
||||
return item
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换工具定义
|
||||
*/
|
||||
_convertTools(tools) {
|
||||
return tools.map(tool => {
|
||||
return tools.map((tool) => {
|
||||
if (tool.type === 'function') {
|
||||
return {
|
||||
name: tool.function.name,
|
||||
description: tool.function.description,
|
||||
input_schema: tool.function.parameters
|
||||
};
|
||||
}
|
||||
}
|
||||
return tool;
|
||||
});
|
||||
return tool
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换工具选择
|
||||
*/
|
||||
_convertToolChoice(toolChoice) {
|
||||
if (toolChoice === 'none') return { type: 'none' };
|
||||
if (toolChoice === 'auto') return { type: 'auto' };
|
||||
if (toolChoice === 'required') return { type: 'any' };
|
||||
if (toolChoice === 'none') {
|
||||
return { type: 'none' }
|
||||
}
|
||||
if (toolChoice === 'auto') {
|
||||
return { type: 'auto' }
|
||||
}
|
||||
if (toolChoice === 'required') {
|
||||
return { type: 'any' }
|
||||
}
|
||||
if (toolChoice.type === 'function') {
|
||||
return {
|
||||
type: 'tool',
|
||||
name: toolChoice.function.name
|
||||
};
|
||||
}
|
||||
}
|
||||
return { type: 'auto' };
|
||||
return { type: 'auto' }
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换工具调用
|
||||
*/
|
||||
_convertToolCalls(toolCalls) {
|
||||
return toolCalls.map(tc => ({
|
||||
return toolCalls.map((tc) => ({
|
||||
type: 'tool_use',
|
||||
id: tc.id,
|
||||
name: tc.function.name,
|
||||
input: JSON.parse(tc.function.arguments)
|
||||
}));
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -310,20 +331,20 @@ class OpenAIToClaudeConverter {
|
||||
const message = {
|
||||
role: 'assistant',
|
||||
content: null
|
||||
};
|
||||
}
|
||||
|
||||
// 处理内容
|
||||
if (claudeResponse.content) {
|
||||
if (typeof claudeResponse.content === 'string') {
|
||||
message.content = claudeResponse.content;
|
||||
message.content = claudeResponse.content
|
||||
} else if (Array.isArray(claudeResponse.content)) {
|
||||
// 提取文本内容和工具调用
|
||||
const textParts = [];
|
||||
const toolCalls = [];
|
||||
|
||||
const textParts = []
|
||||
const toolCalls = []
|
||||
|
||||
for (const item of claudeResponse.content) {
|
||||
if (item.type === 'text') {
|
||||
textParts.push(item.text);
|
||||
textParts.push(item.text)
|
||||
} else if (item.type === 'tool_use') {
|
||||
toolCalls.push({
|
||||
id: item.id,
|
||||
@@ -332,114 +353,121 @@ class OpenAIToClaudeConverter {
|
||||
name: item.name,
|
||||
arguments: JSON.stringify(item.input)
|
||||
}
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
message.content = textParts.join('') || null;
|
||||
|
||||
message.content = textParts.join('') || null
|
||||
if (toolCalls.length > 0) {
|
||||
message.tool_calls = toolCalls;
|
||||
message.tool_calls = toolCalls
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return message;
|
||||
return message
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换停止原因
|
||||
*/
|
||||
_mapStopReason(claudeReason) {
|
||||
return this.stopReasonMapping[claudeReason] || 'stop';
|
||||
return this.stopReasonMapping[claudeReason] || 'stop'
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换使用统计
|
||||
*/
|
||||
_convertUsage(claudeUsage) {
|
||||
if (!claudeUsage) return undefined;
|
||||
|
||||
if (!claudeUsage) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return {
|
||||
prompt_tokens: claudeUsage.input_tokens || 0,
|
||||
completion_tokens: claudeUsage.output_tokens || 0,
|
||||
total_tokens: (claudeUsage.input_tokens || 0) + (claudeUsage.output_tokens || 0)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换流式事件
|
||||
*/
|
||||
_convertStreamEvent(event, requestModel, sessionId) {
|
||||
const timestamp = Math.floor(Date.now() / 1000);
|
||||
const timestamp = Math.floor(Date.now() / 1000)
|
||||
const baseChunk = {
|
||||
id: sessionId,
|
||||
object: 'chat.completion.chunk',
|
||||
created: timestamp,
|
||||
model: requestModel || 'gpt-4',
|
||||
choices: [{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: null
|
||||
}]
|
||||
};
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {},
|
||||
finish_reason: null
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
// 根据事件类型处理
|
||||
if (event.type === 'message_start') {
|
||||
// 处理消息开始事件,发送角色信息
|
||||
baseChunk.choices[0].delta.role = 'assistant';
|
||||
return baseChunk;
|
||||
baseChunk.choices[0].delta.role = 'assistant'
|
||||
return baseChunk
|
||||
} else if (event.type === 'content_block_start' && event.content_block) {
|
||||
if (event.content_block.type === 'text') {
|
||||
baseChunk.choices[0].delta.content = event.content_block.text || '';
|
||||
baseChunk.choices[0].delta.content = event.content_block.text || ''
|
||||
} else if (event.content_block.type === 'tool_use') {
|
||||
// 开始工具调用
|
||||
baseChunk.choices[0].delta.tool_calls = [{
|
||||
index: event.index || 0,
|
||||
id: event.content_block.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: event.content_block.name,
|
||||
arguments: ''
|
||||
baseChunk.choices[0].delta.tool_calls = [
|
||||
{
|
||||
index: event.index || 0,
|
||||
id: event.content_block.id,
|
||||
type: 'function',
|
||||
function: {
|
||||
name: event.content_block.name,
|
||||
arguments: ''
|
||||
}
|
||||
}
|
||||
}];
|
||||
]
|
||||
}
|
||||
} else if (event.type === 'content_block_delta' && event.delta) {
|
||||
if (event.delta.type === 'text_delta') {
|
||||
baseChunk.choices[0].delta.content = event.delta.text || '';
|
||||
baseChunk.choices[0].delta.content = event.delta.text || ''
|
||||
} else if (event.delta.type === 'input_json_delta') {
|
||||
// 工具调用参数的增量更新
|
||||
baseChunk.choices[0].delta.tool_calls = [{
|
||||
index: event.index || 0,
|
||||
function: {
|
||||
arguments: event.delta.partial_json || ''
|
||||
baseChunk.choices[0].delta.tool_calls = [
|
||||
{
|
||||
index: event.index || 0,
|
||||
function: {
|
||||
arguments: event.delta.partial_json || ''
|
||||
}
|
||||
}
|
||||
}];
|
||||
]
|
||||
}
|
||||
} else if (event.type === 'message_delta' && event.delta) {
|
||||
if (event.delta.stop_reason) {
|
||||
baseChunk.choices[0].finish_reason = this._mapStopReason(event.delta.stop_reason);
|
||||
baseChunk.choices[0].finish_reason = this._mapStopReason(event.delta.stop_reason)
|
||||
}
|
||||
if (event.usage) {
|
||||
baseChunk.usage = this._convertUsage(event.usage);
|
||||
baseChunk.usage = this._convertUsage(event.usage)
|
||||
}
|
||||
} else if (event.type === 'message_stop') {
|
||||
// message_stop 事件不需要返回 chunk,[DONE] 标记会在 convertStreamChunk 中添加
|
||||
return null;
|
||||
return null
|
||||
} else {
|
||||
// 忽略其他类型的事件
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
return baseChunk;
|
||||
return baseChunk
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成随机 ID
|
||||
*/
|
||||
_generateId() {
|
||||
return Math.random().toString(36).substring(2, 15) +
|
||||
Math.random().toString(36).substring(2, 15);
|
||||
return Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15)
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new OpenAIToClaudeConverter();
|
||||
module.exports = new OpenAIToClaudeConverter()
|
||||
|
||||
@@ -1,19 +1,25 @@
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const https = require('https');
|
||||
const logger = require('../utils/logger');
|
||||
const fs = require('fs')
|
||||
const path = require('path')
|
||||
const https = require('https')
|
||||
const logger = require('../utils/logger')
|
||||
|
||||
class PricingService {
|
||||
constructor() {
|
||||
this.dataDir = path.join(process.cwd(), 'data');
|
||||
this.pricingFile = path.join(this.dataDir, 'model_pricing.json');
|
||||
this.pricingUrl = 'https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json';
|
||||
this.fallbackFile = path.join(process.cwd(), 'resources', 'model-pricing', 'model_prices_and_context_window.json');
|
||||
this.pricingData = null;
|
||||
this.lastUpdated = null;
|
||||
this.updateInterval = 24 * 60 * 60 * 1000; // 24小时
|
||||
this.fileWatcher = null; // 文件监听器
|
||||
this.reloadDebounceTimer = null; // 防抖定时器
|
||||
this.dataDir = path.join(process.cwd(), 'data')
|
||||
this.pricingFile = path.join(this.dataDir, 'model_pricing.json')
|
||||
this.pricingUrl =
|
||||
'https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json'
|
||||
this.fallbackFile = path.join(
|
||||
process.cwd(),
|
||||
'resources',
|
||||
'model-pricing',
|
||||
'model_prices_and_context_window.json'
|
||||
)
|
||||
this.pricingData = null
|
||||
this.lastUpdated = null
|
||||
this.updateInterval = 24 * 60 * 60 * 1000 // 24小时
|
||||
this.fileWatcher = null // 文件监听器
|
||||
this.reloadDebounceTimer = null // 防抖定时器
|
||||
}
|
||||
|
||||
// 初始化价格服务
|
||||
@@ -21,72 +27,74 @@ class PricingService {
|
||||
try {
|
||||
// 确保data目录存在
|
||||
if (!fs.existsSync(this.dataDir)) {
|
||||
fs.mkdirSync(this.dataDir, { recursive: true });
|
||||
logger.info('📁 Created data directory');
|
||||
fs.mkdirSync(this.dataDir, { recursive: true })
|
||||
logger.info('📁 Created data directory')
|
||||
}
|
||||
|
||||
// 检查是否需要下载或更新价格数据
|
||||
await this.checkAndUpdatePricing();
|
||||
|
||||
await this.checkAndUpdatePricing()
|
||||
|
||||
// 设置定时更新
|
||||
setInterval(() => {
|
||||
this.checkAndUpdatePricing();
|
||||
}, this.updateInterval);
|
||||
this.checkAndUpdatePricing()
|
||||
}, this.updateInterval)
|
||||
|
||||
// 设置文件监听器
|
||||
this.setupFileWatcher();
|
||||
this.setupFileWatcher()
|
||||
|
||||
logger.success('💰 Pricing service initialized successfully');
|
||||
logger.success('💰 Pricing service initialized successfully')
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to initialize pricing service:', error);
|
||||
logger.error('❌ Failed to initialize pricing service:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查并更新价格数据
|
||||
async checkAndUpdatePricing() {
|
||||
try {
|
||||
const needsUpdate = this.needsUpdate();
|
||||
|
||||
const needsUpdate = this.needsUpdate()
|
||||
|
||||
if (needsUpdate) {
|
||||
logger.info('🔄 Updating model pricing data...');
|
||||
await this.downloadPricingData();
|
||||
logger.info('🔄 Updating model pricing data...')
|
||||
await this.downloadPricingData()
|
||||
} else {
|
||||
// 如果不需要更新,加载现有数据
|
||||
await this.loadPricingData();
|
||||
await this.loadPricingData()
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to check/update pricing:', error);
|
||||
logger.error('❌ Failed to check/update pricing:', error)
|
||||
// 如果更新失败,尝试使用fallback
|
||||
await this.useFallbackPricing();
|
||||
await this.useFallbackPricing()
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要更新
|
||||
needsUpdate() {
|
||||
if (!fs.existsSync(this.pricingFile)) {
|
||||
logger.info('📋 Pricing file not found, will download');
|
||||
return true;
|
||||
logger.info('📋 Pricing file not found, will download')
|
||||
return true
|
||||
}
|
||||
|
||||
const stats = fs.statSync(this.pricingFile);
|
||||
const fileAge = Date.now() - stats.mtime.getTime();
|
||||
|
||||
const stats = fs.statSync(this.pricingFile)
|
||||
const fileAge = Date.now() - stats.mtime.getTime()
|
||||
|
||||
if (fileAge > this.updateInterval) {
|
||||
logger.info(`📋 Pricing file is ${Math.round(fileAge / (60 * 60 * 1000))} hours old, will update`);
|
||||
return true;
|
||||
logger.info(
|
||||
`📋 Pricing file is ${Math.round(fileAge / (60 * 60 * 1000))} hours old, will update`
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
return false;
|
||||
return false
|
||||
}
|
||||
|
||||
// 下载价格数据
|
||||
async downloadPricingData() {
|
||||
try {
|
||||
await this._downloadFromRemote();
|
||||
await this._downloadFromRemote()
|
||||
} catch (downloadError) {
|
||||
logger.warn(`⚠️ Failed to download pricing data: ${downloadError.message}`);
|
||||
logger.info('📋 Using local fallback pricing data...');
|
||||
await this.useFallbackPricing();
|
||||
logger.warn(`⚠️ Failed to download pricing data: ${downloadError.message}`)
|
||||
logger.info('📋 Using local fallback pricing data...')
|
||||
await this.useFallbackPricing()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,67 +103,69 @@ class PricingService {
|
||||
return new Promise((resolve, reject) => {
|
||||
const request = https.get(this.pricingUrl, (response) => {
|
||||
if (response.statusCode !== 200) {
|
||||
reject(new Error(`HTTP ${response.statusCode}: ${response.statusMessage}`));
|
||||
return;
|
||||
reject(new Error(`HTTP ${response.statusCode}: ${response.statusMessage}`))
|
||||
return
|
||||
}
|
||||
|
||||
let data = '';
|
||||
let data = ''
|
||||
response.on('data', (chunk) => {
|
||||
data += chunk;
|
||||
});
|
||||
data += chunk
|
||||
})
|
||||
|
||||
response.on('end', () => {
|
||||
try {
|
||||
const jsonData = JSON.parse(data);
|
||||
|
||||
const jsonData = JSON.parse(data)
|
||||
|
||||
// 保存到文件
|
||||
fs.writeFileSync(this.pricingFile, JSON.stringify(jsonData, null, 2));
|
||||
|
||||
fs.writeFileSync(this.pricingFile, JSON.stringify(jsonData, null, 2))
|
||||
|
||||
// 更新内存中的数据
|
||||
this.pricingData = jsonData;
|
||||
this.lastUpdated = new Date();
|
||||
|
||||
logger.success(`💰 Downloaded pricing data for ${Object.keys(jsonData).length} models`);
|
||||
|
||||
this.pricingData = jsonData
|
||||
this.lastUpdated = new Date()
|
||||
|
||||
logger.success(`💰 Downloaded pricing data for ${Object.keys(jsonData).length} models`)
|
||||
|
||||
// 设置或重新设置文件监听器
|
||||
this.setupFileWatcher();
|
||||
|
||||
resolve();
|
||||
this.setupFileWatcher()
|
||||
|
||||
resolve()
|
||||
} catch (error) {
|
||||
reject(new Error(`Failed to parse pricing data: ${error.message}`));
|
||||
reject(new Error(`Failed to parse pricing data: ${error.message}`))
|
||||
}
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
|
||||
request.on('error', (error) => {
|
||||
reject(new Error(`Network error: ${error.message}`));
|
||||
});
|
||||
reject(new Error(`Network error: ${error.message}`))
|
||||
})
|
||||
|
||||
request.setTimeout(30000, () => {
|
||||
request.destroy();
|
||||
reject(new Error('Download timeout after 30 seconds'));
|
||||
});
|
||||
});
|
||||
request.destroy()
|
||||
reject(new Error('Download timeout after 30 seconds'))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 加载本地价格数据
|
||||
async loadPricingData() {
|
||||
try {
|
||||
if (fs.existsSync(this.pricingFile)) {
|
||||
const data = fs.readFileSync(this.pricingFile, 'utf8');
|
||||
this.pricingData = JSON.parse(data);
|
||||
|
||||
const stats = fs.statSync(this.pricingFile);
|
||||
this.lastUpdated = stats.mtime;
|
||||
|
||||
logger.info(`💰 Loaded pricing data for ${Object.keys(this.pricingData).length} models from cache`);
|
||||
const data = fs.readFileSync(this.pricingFile, 'utf8')
|
||||
this.pricingData = JSON.parse(data)
|
||||
|
||||
const stats = fs.statSync(this.pricingFile)
|
||||
this.lastUpdated = stats.mtime
|
||||
|
||||
logger.info(
|
||||
`💰 Loaded pricing data for ${Object.keys(this.pricingData).length} models from cache`
|
||||
)
|
||||
} else {
|
||||
logger.warn('💰 No pricing data file found, will use fallback');
|
||||
await this.useFallbackPricing();
|
||||
logger.warn('💰 No pricing data file found, will use fallback')
|
||||
await this.useFallbackPricing()
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to load pricing data:', error);
|
||||
await this.useFallbackPricing();
|
||||
logger.error('❌ Failed to load pricing data:', error)
|
||||
await this.useFallbackPricing()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,89 +173,95 @@ class PricingService {
|
||||
async useFallbackPricing() {
|
||||
try {
|
||||
if (fs.existsSync(this.fallbackFile)) {
|
||||
logger.info('📋 Copying fallback pricing data to data directory...');
|
||||
|
||||
logger.info('📋 Copying fallback pricing data to data directory...')
|
||||
|
||||
// 读取fallback文件
|
||||
const fallbackData = fs.readFileSync(this.fallbackFile, 'utf8');
|
||||
const jsonData = JSON.parse(fallbackData);
|
||||
|
||||
const fallbackData = fs.readFileSync(this.fallbackFile, 'utf8')
|
||||
const jsonData = JSON.parse(fallbackData)
|
||||
|
||||
// 保存到data目录
|
||||
fs.writeFileSync(this.pricingFile, JSON.stringify(jsonData, null, 2));
|
||||
|
||||
fs.writeFileSync(this.pricingFile, JSON.stringify(jsonData, null, 2))
|
||||
|
||||
// 更新内存中的数据
|
||||
this.pricingData = jsonData;
|
||||
this.lastUpdated = new Date();
|
||||
|
||||
this.pricingData = jsonData
|
||||
this.lastUpdated = new Date()
|
||||
|
||||
// 设置或重新设置文件监听器
|
||||
this.setupFileWatcher();
|
||||
|
||||
logger.warn(`⚠️ Using fallback pricing data for ${Object.keys(jsonData).length} models`);
|
||||
logger.info('💡 Note: This fallback data may be outdated. The system will try to update from the remote source on next check.');
|
||||
this.setupFileWatcher()
|
||||
|
||||
logger.warn(`⚠️ Using fallback pricing data for ${Object.keys(jsonData).length} models`)
|
||||
logger.info(
|
||||
'💡 Note: This fallback data may be outdated. The system will try to update from the remote source on next check.'
|
||||
)
|
||||
} else {
|
||||
logger.error('❌ Fallback pricing file not found at:', this.fallbackFile);
|
||||
logger.error('❌ Please ensure the resources/model-pricing directory exists with the pricing file');
|
||||
this.pricingData = {};
|
||||
logger.error('❌ Fallback pricing file not found at:', this.fallbackFile)
|
||||
logger.error(
|
||||
'❌ Please ensure the resources/model-pricing directory exists with the pricing file'
|
||||
)
|
||||
this.pricingData = {}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to use fallback pricing data:', error);
|
||||
this.pricingData = {};
|
||||
logger.error('❌ Failed to use fallback pricing data:', error)
|
||||
this.pricingData = {}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取模型价格信息
|
||||
getModelPricing(modelName) {
|
||||
if (!this.pricingData || !modelName) {
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
// 尝试直接匹配
|
||||
if (this.pricingData[modelName]) {
|
||||
return this.pricingData[modelName];
|
||||
return this.pricingData[modelName]
|
||||
}
|
||||
|
||||
// 对于Bedrock区域前缀模型(如 us.anthropic.claude-sonnet-4-20250514-v1:0),
|
||||
// 尝试去掉区域前缀进行匹配
|
||||
if (modelName.includes('.anthropic.') || modelName.includes('.claude')) {
|
||||
// 提取不带区域前缀的模型名
|
||||
const withoutRegion = modelName.replace(/^(us|eu|apac)\./, '');
|
||||
const withoutRegion = modelName.replace(/^(us|eu|apac)\./, '')
|
||||
if (this.pricingData[withoutRegion]) {
|
||||
logger.debug(`💰 Found pricing for ${modelName} by removing region prefix: ${withoutRegion}`);
|
||||
return this.pricingData[withoutRegion];
|
||||
logger.debug(
|
||||
`💰 Found pricing for ${modelName} by removing region prefix: ${withoutRegion}`
|
||||
)
|
||||
return this.pricingData[withoutRegion]
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试模糊匹配(处理版本号等变化)
|
||||
const normalizedModel = modelName.toLowerCase().replace(/[_-]/g, '');
|
||||
|
||||
const normalizedModel = modelName.toLowerCase().replace(/[_-]/g, '')
|
||||
|
||||
for (const [key, value] of Object.entries(this.pricingData)) {
|
||||
const normalizedKey = key.toLowerCase().replace(/[_-]/g, '');
|
||||
const normalizedKey = key.toLowerCase().replace(/[_-]/g, '')
|
||||
if (normalizedKey.includes(normalizedModel) || normalizedModel.includes(normalizedKey)) {
|
||||
logger.debug(`💰 Found pricing for ${modelName} using fuzzy match: ${key}`);
|
||||
return value;
|
||||
logger.debug(`💰 Found pricing for ${modelName} using fuzzy match: ${key}`)
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// 对于Bedrock模型,尝试更智能的匹配
|
||||
if (modelName.includes('anthropic.claude')) {
|
||||
// 提取核心模型名部分(去掉区域和前缀)
|
||||
const coreModel = modelName.replace(/^(us|eu|apac)\./, '').replace('anthropic.', '');
|
||||
|
||||
const coreModel = modelName.replace(/^(us|eu|apac)\./, '').replace('anthropic.', '')
|
||||
|
||||
for (const [key, value] of Object.entries(this.pricingData)) {
|
||||
if (key.includes(coreModel) || key.replace('anthropic.', '').includes(coreModel)) {
|
||||
logger.debug(`💰 Found pricing for ${modelName} using Bedrock core model match: ${key}`);
|
||||
return value;
|
||||
logger.debug(`💰 Found pricing for ${modelName} using Bedrock core model match: ${key}`)
|
||||
return value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(`💰 No pricing found for model: ${modelName}`);
|
||||
return null;
|
||||
logger.debug(`💰 No pricing found for model: ${modelName}`)
|
||||
return null
|
||||
}
|
||||
|
||||
// 计算使用费用
|
||||
calculateCost(usage, modelName) {
|
||||
const pricing = this.getModelPricing(modelName);
|
||||
|
||||
const pricing = this.getModelPricing(modelName)
|
||||
|
||||
if (!pricing) {
|
||||
return {
|
||||
inputCost: 0,
|
||||
@@ -254,13 +270,15 @@ class PricingService {
|
||||
cacheReadCost: 0,
|
||||
totalCost: 0,
|
||||
hasPricing: false
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const inputCost = (usage.input_tokens || 0) * (pricing.input_cost_per_token || 0);
|
||||
const outputCost = (usage.output_tokens || 0) * (pricing.output_cost_per_token || 0);
|
||||
const cacheCreateCost = (usage.cache_creation_input_tokens || 0) * (pricing.cache_creation_input_token_cost || 0);
|
||||
const cacheReadCost = (usage.cache_read_input_tokens || 0) * (pricing.cache_read_input_token_cost || 0);
|
||||
const inputCost = (usage.input_tokens || 0) * (pricing.input_cost_per_token || 0)
|
||||
const outputCost = (usage.output_tokens || 0) * (pricing.output_cost_per_token || 0)
|
||||
const cacheCreateCost =
|
||||
(usage.cache_creation_input_tokens || 0) * (pricing.cache_creation_input_token_cost || 0)
|
||||
const cacheReadCost =
|
||||
(usage.cache_read_input_tokens || 0) * (pricing.cache_read_input_token_cost || 0)
|
||||
|
||||
return {
|
||||
inputCost,
|
||||
@@ -275,16 +293,24 @@ class PricingService {
|
||||
cacheCreate: pricing.cache_creation_input_token_cost || 0,
|
||||
cacheRead: pricing.cache_read_input_token_cost || 0
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化价格显示
|
||||
formatCost(cost) {
|
||||
if (cost === 0) return '$0.000000';
|
||||
if (cost < 0.000001) return `$${cost.toExponential(2)}`;
|
||||
if (cost < 0.01) return `$${cost.toFixed(6)}`;
|
||||
if (cost < 1) return `$${cost.toFixed(4)}`;
|
||||
return `$${cost.toFixed(2)}`;
|
||||
if (cost === 0) {
|
||||
return '$0.000000'
|
||||
}
|
||||
if (cost < 0.000001) {
|
||||
return `$${cost.toExponential(2)}`
|
||||
}
|
||||
if (cost < 0.01) {
|
||||
return `$${cost.toFixed(6)}`
|
||||
}
|
||||
if (cost < 1) {
|
||||
return `$${cost.toFixed(4)}`
|
||||
}
|
||||
return `$${cost.toFixed(2)}`
|
||||
}
|
||||
|
||||
// 获取服务状态
|
||||
@@ -293,23 +319,25 @@ class PricingService {
|
||||
initialized: this.pricingData !== null,
|
||||
lastUpdated: this.lastUpdated,
|
||||
modelCount: this.pricingData ? Object.keys(this.pricingData).length : 0,
|
||||
nextUpdate: this.lastUpdated ? new Date(this.lastUpdated.getTime() + this.updateInterval) : null
|
||||
};
|
||||
nextUpdate: this.lastUpdated
|
||||
? new Date(this.lastUpdated.getTime() + this.updateInterval)
|
||||
: null
|
||||
}
|
||||
}
|
||||
|
||||
// 强制更新价格数据
|
||||
async forceUpdate() {
|
||||
try {
|
||||
await this._downloadFromRemote();
|
||||
return { success: true, message: 'Pricing data updated successfully' };
|
||||
await this._downloadFromRemote()
|
||||
return { success: true, message: 'Pricing data updated successfully' }
|
||||
} catch (error) {
|
||||
logger.error('❌ Force update failed:', error);
|
||||
logger.info('📋 Force update failed, using fallback pricing data...');
|
||||
await this.useFallbackPricing();
|
||||
return {
|
||||
success: false,
|
||||
message: `Download failed: ${error.message}. Using fallback pricing data instead.`
|
||||
};
|
||||
logger.error('❌ Force update failed:', error)
|
||||
logger.info('📋 Force update failed, using fallback pricing data...')
|
||||
await this.useFallbackPricing()
|
||||
return {
|
||||
success: false,
|
||||
message: `Download failed: ${error.message}. Using fallback pricing data instead.`
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,43 +346,45 @@ class PricingService {
|
||||
try {
|
||||
// 如果已有监听器,先关闭
|
||||
if (this.fileWatcher) {
|
||||
this.fileWatcher.close();
|
||||
this.fileWatcher = null;
|
||||
this.fileWatcher.close()
|
||||
this.fileWatcher = null
|
||||
}
|
||||
|
||||
// 只有文件存在时才设置监听器
|
||||
if (!fs.existsSync(this.pricingFile)) {
|
||||
logger.debug('💰 Pricing file does not exist yet, skipping file watcher setup');
|
||||
return;
|
||||
logger.debug('💰 Pricing file does not exist yet, skipping file watcher setup')
|
||||
return
|
||||
}
|
||||
|
||||
// 使用 fs.watchFile 作为更可靠的文件监听方式
|
||||
// 它使用轮询,虽然性能稍差,但更可靠
|
||||
const watchOptions = {
|
||||
persistent: true,
|
||||
const watchOptions = {
|
||||
persistent: true,
|
||||
interval: 60000 // 每60秒检查一次
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 记录初始的修改时间
|
||||
let lastMtime = fs.statSync(this.pricingFile).mtimeMs;
|
||||
|
||||
let lastMtime = fs.statSync(this.pricingFile).mtimeMs
|
||||
|
||||
fs.watchFile(this.pricingFile, watchOptions, (curr, _prev) => {
|
||||
// 检查文件是否真的被修改了(不仅仅是访问)
|
||||
if (curr.mtimeMs !== lastMtime) {
|
||||
lastMtime = curr.mtimeMs;
|
||||
logger.debug(`💰 Detected change in pricing file (mtime: ${new Date(curr.mtime).toISOString()})`);
|
||||
this.handleFileChange();
|
||||
lastMtime = curr.mtimeMs
|
||||
logger.debug(
|
||||
`💰 Detected change in pricing file (mtime: ${new Date(curr.mtime).toISOString()})`
|
||||
)
|
||||
this.handleFileChange()
|
||||
}
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
// 保存引用以便清理
|
||||
this.fileWatcher = {
|
||||
close: () => fs.unwatchFile(this.pricingFile)
|
||||
};
|
||||
}
|
||||
|
||||
logger.info('👁️ File watcher set up for model_pricing.json (polling every 60s)');
|
||||
logger.info('👁️ File watcher set up for model_pricing.json (polling every 60s)')
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to setup file watcher:', error);
|
||||
logger.error('❌ Failed to setup file watcher:', error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,14 +392,14 @@ class PricingService {
|
||||
handleFileChange() {
|
||||
// 清除之前的定时器
|
||||
if (this.reloadDebounceTimer) {
|
||||
clearTimeout(this.reloadDebounceTimer);
|
||||
clearTimeout(this.reloadDebounceTimer)
|
||||
}
|
||||
|
||||
// 设置新的定时器(防抖500ms)
|
||||
this.reloadDebounceTimer = setTimeout(async () => {
|
||||
logger.info('🔄 Reloading pricing data due to file change...');
|
||||
await this.reloadPricingData();
|
||||
}, 500);
|
||||
logger.info('🔄 Reloading pricing data due to file change...')
|
||||
await this.reloadPricingData()
|
||||
}, 500)
|
||||
}
|
||||
|
||||
// 重新加载价格数据
|
||||
@@ -377,55 +407,57 @@ class PricingService {
|
||||
try {
|
||||
// 验证文件是否存在
|
||||
if (!fs.existsSync(this.pricingFile)) {
|
||||
logger.warn('💰 Pricing file was deleted, using fallback');
|
||||
await this.useFallbackPricing();
|
||||
logger.warn('💰 Pricing file was deleted, using fallback')
|
||||
await this.useFallbackPricing()
|
||||
// 重新设置文件监听器(fallback会创建新文件)
|
||||
this.setupFileWatcher();
|
||||
return;
|
||||
this.setupFileWatcher()
|
||||
return
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
const data = fs.readFileSync(this.pricingFile, 'utf8');
|
||||
|
||||
const data = fs.readFileSync(this.pricingFile, 'utf8')
|
||||
|
||||
// 尝试解析JSON
|
||||
const jsonData = JSON.parse(data);
|
||||
|
||||
const jsonData = JSON.parse(data)
|
||||
|
||||
// 验证数据结构
|
||||
if (typeof jsonData !== 'object' || Object.keys(jsonData).length === 0) {
|
||||
throw new Error('Invalid pricing data structure');
|
||||
throw new Error('Invalid pricing data structure')
|
||||
}
|
||||
|
||||
// 更新内存中的数据
|
||||
this.pricingData = jsonData;
|
||||
this.lastUpdated = new Date();
|
||||
|
||||
const modelCount = Object.keys(jsonData).length;
|
||||
logger.success(`💰 Reloaded pricing data for ${modelCount} models from file`);
|
||||
|
||||
this.pricingData = jsonData
|
||||
this.lastUpdated = new Date()
|
||||
|
||||
const modelCount = Object.keys(jsonData).length
|
||||
logger.success(`💰 Reloaded pricing data for ${modelCount} models from file`)
|
||||
|
||||
// 显示一些统计信息
|
||||
const claudeModels = Object.keys(jsonData).filter(k => k.includes('claude')).length;
|
||||
const gptModels = Object.keys(jsonData).filter(k => k.includes('gpt')).length;
|
||||
const geminiModels = Object.keys(jsonData).filter(k => k.includes('gemini')).length;
|
||||
|
||||
logger.debug(`💰 Model breakdown: Claude=${claudeModels}, GPT=${gptModels}, Gemini=${geminiModels}`);
|
||||
const claudeModels = Object.keys(jsonData).filter((k) => k.includes('claude')).length
|
||||
const gptModels = Object.keys(jsonData).filter((k) => k.includes('gpt')).length
|
||||
const geminiModels = Object.keys(jsonData).filter((k) => k.includes('gemini')).length
|
||||
|
||||
logger.debug(
|
||||
`💰 Model breakdown: Claude=${claudeModels}, GPT=${gptModels}, Gemini=${geminiModels}`
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to reload pricing data:', error);
|
||||
logger.warn('💰 Keeping existing pricing data in memory');
|
||||
logger.error('❌ Failed to reload pricing data:', error)
|
||||
logger.warn('💰 Keeping existing pricing data in memory')
|
||||
}
|
||||
}
|
||||
|
||||
// 清理资源
|
||||
cleanup() {
|
||||
if (this.fileWatcher) {
|
||||
this.fileWatcher.close();
|
||||
this.fileWatcher = null;
|
||||
logger.debug('💰 File watcher closed');
|
||||
this.fileWatcher.close()
|
||||
this.fileWatcher = null
|
||||
logger.debug('💰 File watcher closed')
|
||||
}
|
||||
if (this.reloadDebounceTimer) {
|
||||
clearTimeout(this.reloadDebounceTimer);
|
||||
this.reloadDebounceTimer = null;
|
||||
clearTimeout(this.reloadDebounceTimer)
|
||||
this.reloadDebounceTimer = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new PricingService();
|
||||
module.exports = new PricingService()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
const { v4: uuidv4 } = require('uuid')
|
||||
|
||||
/**
|
||||
* Token 刷新锁服务
|
||||
@@ -8,30 +8,29 @@ const { v4: uuidv4 } = require('uuid');
|
||||
*/
|
||||
class TokenRefreshService {
|
||||
constructor() {
|
||||
this.lockTTL = 60; // 锁的TTL: 60秒(token刷新通常在30秒内完成)
|
||||
this.lockValue = new Map(); // 存储每个锁的唯一值
|
||||
this.lockTTL = 60 // 锁的TTL: 60秒(token刷新通常在30秒内完成)
|
||||
this.lockValue = new Map() // 存储每个锁的唯一值
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取分布式锁
|
||||
* 使用唯一标识符作为值,避免误释放其他进程的锁
|
||||
*/
|
||||
async acquireLock(lockKey) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const lockId = uuidv4();
|
||||
const result = await client.set(lockKey, lockId, 'NX', 'EX', this.lockTTL);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const lockId = uuidv4()
|
||||
const result = await client.set(lockKey, lockId, 'NX', 'EX', this.lockTTL)
|
||||
|
||||
if (result === 'OK') {
|
||||
this.lockValue.set(lockKey, lockId);
|
||||
logger.debug(`🔒 Acquired lock ${lockKey} with ID ${lockId}, TTL: ${this.lockTTL}s`);
|
||||
return true;
|
||||
this.lockValue.set(lockKey, lockId)
|
||||
logger.debug(`🔒 Acquired lock ${lockKey} with ID ${lockId}, TTL: ${this.lockTTL}s`)
|
||||
return true
|
||||
}
|
||||
return false;
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.error(`Failed to acquire lock ${lockKey}:`, error);
|
||||
return false;
|
||||
logger.error(`Failed to acquire lock ${lockKey}:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,12 +40,12 @@ class TokenRefreshService {
|
||||
*/
|
||||
async releaseLock(lockKey) {
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const lockId = this.lockValue.get(lockKey);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const lockId = this.lockValue.get(lockKey)
|
||||
|
||||
if (!lockId) {
|
||||
logger.warn(`⚠️ No lock ID found for ${lockKey}, skipping release`);
|
||||
return;
|
||||
logger.warn(`⚠️ No lock ID found for ${lockKey}, skipping release`)
|
||||
return
|
||||
}
|
||||
|
||||
// Lua 脚本:只有当值匹配时才删除
|
||||
@@ -56,18 +55,18 @@ class TokenRefreshService {
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`;
|
||||
|
||||
const result = await client.eval(luaScript, 1, lockKey, lockId);
|
||||
|
||||
`
|
||||
|
||||
const result = await client.eval(luaScript, 1, lockKey, lockId)
|
||||
|
||||
if (result === 1) {
|
||||
this.lockValue.delete(lockKey);
|
||||
logger.debug(`🔓 Released lock ${lockKey} with ID ${lockId}`);
|
||||
this.lockValue.delete(lockKey)
|
||||
logger.debug(`🔓 Released lock ${lockKey} with ID ${lockId}`)
|
||||
} else {
|
||||
logger.warn(`⚠️ Lock ${lockKey} was not released - value mismatch or already expired`);
|
||||
logger.warn(`⚠️ Lock ${lockKey} was not released - value mismatch or already expired`)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to release lock ${lockKey}:`, error);
|
||||
logger.error(`Failed to release lock ${lockKey}:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,8 +77,8 @@ class TokenRefreshService {
|
||||
* @returns {Promise<boolean>} 是否成功获取锁
|
||||
*/
|
||||
async acquireRefreshLock(accountId, platform = 'claude') {
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`;
|
||||
return await this.acquireLock(lockKey);
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`
|
||||
return await this.acquireLock(lockKey)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -88,8 +87,8 @@ class TokenRefreshService {
|
||||
* @param {string} platform - 平台类型 (claude/gemini)
|
||||
*/
|
||||
async releaseRefreshLock(accountId, platform = 'claude') {
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`;
|
||||
await this.releaseLock(lockKey);
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`
|
||||
await this.releaseLock(lockKey)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -99,14 +98,14 @@ class TokenRefreshService {
|
||||
* @returns {Promise<boolean>} 锁是否存在
|
||||
*/
|
||||
async isRefreshLocked(accountId, platform = 'claude') {
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`;
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const exists = await client.exists(lockKey);
|
||||
return exists === 1;
|
||||
const client = redis.getClientSafe()
|
||||
const exists = await client.exists(lockKey)
|
||||
return exists === 1
|
||||
} catch (error) {
|
||||
logger.error(`Failed to check lock status ${lockKey}:`, error);
|
||||
return false;
|
||||
logger.error(`Failed to check lock status ${lockKey}:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,14 +116,14 @@ class TokenRefreshService {
|
||||
* @returns {Promise<number>} 剩余秒数,-1表示锁不存在
|
||||
*/
|
||||
async getLockTTL(accountId, platform = 'claude') {
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`;
|
||||
const lockKey = `token_refresh_lock:${platform}:${accountId}`
|
||||
try {
|
||||
const client = redis.getClientSafe();
|
||||
const ttl = await client.ttl(lockKey);
|
||||
return ttl;
|
||||
const client = redis.getClientSafe()
|
||||
const ttl = await client.ttl(lockKey)
|
||||
return ttl
|
||||
} catch (error) {
|
||||
logger.error(`Failed to get lock TTL ${lockKey}:`, error);
|
||||
return -1;
|
||||
logger.error(`Failed to get lock TTL ${lockKey}:`, error)
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,12 +132,12 @@ class TokenRefreshService {
|
||||
* 在进程退出时调用,避免内存泄漏
|
||||
*/
|
||||
cleanup() {
|
||||
this.lockValue.clear();
|
||||
logger.info('🧹 Cleaned up local lock records');
|
||||
this.lockValue.clear()
|
||||
logger.info('🧹 Cleaned up local lock records')
|
||||
}
|
||||
}
|
||||
|
||||
// 创建单例实例
|
||||
const tokenRefreshService = new TokenRefreshService();
|
||||
const tokenRefreshService = new TokenRefreshService()
|
||||
|
||||
module.exports = tokenRefreshService;
|
||||
module.exports = tokenRefreshService
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
const claudeAccountService = require('./claudeAccountService');
|
||||
const claudeConsoleAccountService = require('./claudeConsoleAccountService');
|
||||
const bedrockAccountService = require('./bedrockAccountService');
|
||||
const accountGroupService = require('./accountGroupService');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const claudeAccountService = require('./claudeAccountService')
|
||||
const claudeConsoleAccountService = require('./claudeConsoleAccountService')
|
||||
const bedrockAccountService = require('./bedrockAccountService')
|
||||
const accountGroupService = require('./accountGroupService')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
|
||||
class UnifiedClaudeScheduler {
|
||||
constructor() {
|
||||
this.SESSION_MAPPING_PREFIX = 'unified_claude_session_mapping:';
|
||||
this.SESSION_MAPPING_PREFIX = 'unified_claude_session_mapping:'
|
||||
}
|
||||
|
||||
// 🔧 辅助方法:检查账户是否可调度(兼容字符串和布尔值)
|
||||
_isSchedulable(schedulable) {
|
||||
// 如果是 undefined 或 null,默认为可调度
|
||||
if (schedulable === undefined || schedulable === null) {
|
||||
return true;
|
||||
return true
|
||||
}
|
||||
// 明确设置为 false(布尔值)或 'false'(字符串)时不可调度
|
||||
return schedulable !== false && schedulable !== 'false';
|
||||
return schedulable !== false && schedulable !== 'false'
|
||||
}
|
||||
|
||||
// 🎯 统一调度Claude账号(官方和Console)
|
||||
@@ -27,177 +27,248 @@ class UnifiedClaudeScheduler {
|
||||
if (apiKeyData.claudeAccountId) {
|
||||
// 检查是否是分组
|
||||
if (apiKeyData.claudeAccountId.startsWith('group:')) {
|
||||
const groupId = apiKeyData.claudeAccountId.replace('group:', '');
|
||||
logger.info(`🎯 API key ${apiKeyData.name} is bound to group ${groupId}, selecting from group`);
|
||||
return await this.selectAccountFromGroup(groupId, sessionHash, requestedModel);
|
||||
const groupId = apiKeyData.claudeAccountId.replace('group:', '')
|
||||
logger.info(
|
||||
`🎯 API key ${apiKeyData.name} is bound to group ${groupId}, selecting from group`
|
||||
)
|
||||
return await this.selectAccountFromGroup(groupId, sessionHash, requestedModel)
|
||||
}
|
||||
|
||||
|
||||
// 普通专属账户
|
||||
const boundAccount = await redis.getClaudeAccount(apiKeyData.claudeAccountId);
|
||||
const boundAccount = await redis.getClaudeAccount(apiKeyData.claudeAccountId)
|
||||
if (boundAccount && boundAccount.isActive === 'true' && boundAccount.status !== 'error') {
|
||||
logger.info(`🎯 Using bound dedicated Claude OAuth account: ${boundAccount.name} (${apiKeyData.claudeAccountId}) for API key ${apiKeyData.name}`);
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Claude OAuth account: ${boundAccount.name} (${apiKeyData.claudeAccountId}) for API key ${apiKeyData.name}`
|
||||
)
|
||||
return {
|
||||
accountId: apiKeyData.claudeAccountId,
|
||||
accountType: 'claude-official'
|
||||
};
|
||||
}
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Claude OAuth account ${apiKeyData.claudeAccountId} is not available, falling back to pool`);
|
||||
logger.warn(
|
||||
`⚠️ Bound Claude OAuth account ${apiKeyData.claudeAccountId} is not available, falling back to pool`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 2. 检查Claude Console账户绑定
|
||||
if (apiKeyData.claudeConsoleAccountId) {
|
||||
const boundConsoleAccount = await claudeConsoleAccountService.getAccount(apiKeyData.claudeConsoleAccountId);
|
||||
if (boundConsoleAccount && boundConsoleAccount.isActive === true && boundConsoleAccount.status === 'active') {
|
||||
logger.info(`🎯 Using bound dedicated Claude Console account: ${boundConsoleAccount.name} (${apiKeyData.claudeConsoleAccountId}) for API key ${apiKeyData.name}`);
|
||||
const boundConsoleAccount = await claudeConsoleAccountService.getAccount(
|
||||
apiKeyData.claudeConsoleAccountId
|
||||
)
|
||||
if (
|
||||
boundConsoleAccount &&
|
||||
boundConsoleAccount.isActive === true &&
|
||||
boundConsoleAccount.status === 'active'
|
||||
) {
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Claude Console account: ${boundConsoleAccount.name} (${apiKeyData.claudeConsoleAccountId}) for API key ${apiKeyData.name}`
|
||||
)
|
||||
return {
|
||||
accountId: apiKeyData.claudeConsoleAccountId,
|
||||
accountType: 'claude-console'
|
||||
};
|
||||
}
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Claude Console account ${apiKeyData.claudeConsoleAccountId} is not available, falling back to pool`);
|
||||
logger.warn(
|
||||
`⚠️ Bound Claude Console account ${apiKeyData.claudeConsoleAccountId} is not available, falling back to pool`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 检查Bedrock账户绑定
|
||||
if (apiKeyData.bedrockAccountId) {
|
||||
const boundBedrockAccountResult = await bedrockAccountService.getAccount(apiKeyData.bedrockAccountId);
|
||||
const boundBedrockAccountResult = await bedrockAccountService.getAccount(
|
||||
apiKeyData.bedrockAccountId
|
||||
)
|
||||
if (boundBedrockAccountResult.success && boundBedrockAccountResult.data.isActive === true) {
|
||||
logger.info(`🎯 Using bound dedicated Bedrock account: ${boundBedrockAccountResult.data.name} (${apiKeyData.bedrockAccountId}) for API key ${apiKeyData.name}`);
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Bedrock account: ${boundBedrockAccountResult.data.name} (${apiKeyData.bedrockAccountId}) for API key ${apiKeyData.name}`
|
||||
)
|
||||
return {
|
||||
accountId: apiKeyData.bedrockAccountId,
|
||||
accountType: 'bedrock'
|
||||
};
|
||||
}
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Bedrock account ${apiKeyData.bedrockAccountId} is not available, falling back to pool`);
|
||||
logger.warn(
|
||||
`⚠️ Bound Bedrock account ${apiKeyData.bedrockAccountId} is not available, falling back to pool`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有会话哈希,检查是否有已映射的账户
|
||||
if (sessionHash) {
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash);
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash)
|
||||
if (mappedAccount) {
|
||||
// 验证映射的账户是否仍然可用
|
||||
const isAvailable = await this._isAccountAvailable(mappedAccount.accountId, mappedAccount.accountType);
|
||||
const isAvailable = await this._isAccountAvailable(
|
||||
mappedAccount.accountId,
|
||||
mappedAccount.accountType
|
||||
)
|
||||
if (isAvailable) {
|
||||
logger.info(`🎯 Using sticky session account: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`);
|
||||
return mappedAccount;
|
||||
logger.info(
|
||||
`🎯 Using sticky session account: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
return mappedAccount
|
||||
} else {
|
||||
logger.warn(`⚠️ Mapped account ${mappedAccount.accountId} is no longer available, selecting new account`);
|
||||
await this._deleteSessionMapping(sessionHash);
|
||||
logger.warn(
|
||||
`⚠️ Mapped account ${mappedAccount.accountId} is no longer available, selecting new account`
|
||||
)
|
||||
await this._deleteSessionMapping(sessionHash)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有可用账户(传递请求的模型进行过滤)
|
||||
const availableAccounts = await this._getAllAvailableAccounts(apiKeyData, requestedModel);
|
||||
|
||||
const availableAccounts = await this._getAllAvailableAccounts(apiKeyData, requestedModel)
|
||||
|
||||
if (availableAccounts.length === 0) {
|
||||
// 提供更详细的错误信息
|
||||
if (requestedModel) {
|
||||
throw new Error(`No available Claude accounts support the requested model: ${requestedModel}`);
|
||||
throw new Error(
|
||||
`No available Claude accounts support the requested model: ${requestedModel}`
|
||||
)
|
||||
} else {
|
||||
throw new Error('No available Claude accounts (neither official nor console)');
|
||||
throw new Error('No available Claude accounts (neither official nor console)')
|
||||
}
|
||||
}
|
||||
|
||||
// 按优先级和最后使用时间排序
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts);
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts)
|
||||
|
||||
// 选择第一个账户
|
||||
const selectedAccount = sortedAccounts[0];
|
||||
|
||||
const selectedAccount = sortedAccounts[0]
|
||||
|
||||
// 如果有会话哈希,建立新的映射
|
||||
if (sessionHash) {
|
||||
await this._setSessionMapping(sessionHash, selectedAccount.accountId, selectedAccount.accountType);
|
||||
logger.info(`🎯 Created new sticky session mapping: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`);
|
||||
await this._setSessionMapping(
|
||||
sessionHash,
|
||||
selectedAccount.accountId,
|
||||
selectedAccount.accountType
|
||||
)
|
||||
logger.info(
|
||||
`🎯 Created new sticky session mapping: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`🎯 Selected account: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority} for API key ${apiKeyData.name}`);
|
||||
|
||||
logger.info(
|
||||
`🎯 Selected account: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority} for API key ${apiKeyData.name}`
|
||||
)
|
||||
|
||||
return {
|
||||
accountId: selectedAccount.accountId,
|
||||
accountType: selectedAccount.accountType
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to select account for API key:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to select account for API key:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📋 获取所有可用账户(合并官方和Console)
|
||||
async _getAllAvailableAccounts(apiKeyData, requestedModel = null) {
|
||||
const availableAccounts = [];
|
||||
const availableAccounts = []
|
||||
|
||||
// 如果API Key绑定了专属账户,优先返回
|
||||
// 1. 检查Claude OAuth账户绑定
|
||||
if (apiKeyData.claudeAccountId) {
|
||||
const boundAccount = await redis.getClaudeAccount(apiKeyData.claudeAccountId);
|
||||
if (boundAccount && boundAccount.isActive === 'true' && boundAccount.status !== 'error' && boundAccount.status !== 'blocked') {
|
||||
const isRateLimited = await claudeAccountService.isAccountRateLimited(boundAccount.id);
|
||||
const boundAccount = await redis.getClaudeAccount(apiKeyData.claudeAccountId)
|
||||
if (
|
||||
boundAccount &&
|
||||
boundAccount.isActive === 'true' &&
|
||||
boundAccount.status !== 'error' &&
|
||||
boundAccount.status !== 'blocked'
|
||||
) {
|
||||
const isRateLimited = await claudeAccountService.isAccountRateLimited(boundAccount.id)
|
||||
if (!isRateLimited) {
|
||||
logger.info(`🎯 Using bound dedicated Claude OAuth account: ${boundAccount.name} (${apiKeyData.claudeAccountId})`);
|
||||
return [{
|
||||
...boundAccount,
|
||||
accountId: boundAccount.id,
|
||||
accountType: 'claude-official',
|
||||
priority: parseInt(boundAccount.priority) || 50,
|
||||
lastUsedAt: boundAccount.lastUsedAt || '0'
|
||||
}];
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Claude OAuth account: ${boundAccount.name} (${apiKeyData.claudeAccountId})`
|
||||
)
|
||||
return [
|
||||
{
|
||||
...boundAccount,
|
||||
accountId: boundAccount.id,
|
||||
accountType: 'claude-official',
|
||||
priority: parseInt(boundAccount.priority) || 50,
|
||||
lastUsedAt: boundAccount.lastUsedAt || '0'
|
||||
}
|
||||
]
|
||||
}
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Claude OAuth account ${apiKeyData.claudeAccountId} is not available`);
|
||||
logger.warn(`⚠️ Bound Claude OAuth account ${apiKeyData.claudeAccountId} is not available`)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 2. 检查Claude Console账户绑定
|
||||
if (apiKeyData.claudeConsoleAccountId) {
|
||||
const boundConsoleAccount = await claudeConsoleAccountService.getAccount(apiKeyData.claudeConsoleAccountId);
|
||||
if (boundConsoleAccount && boundConsoleAccount.isActive === true && boundConsoleAccount.status === 'active') {
|
||||
const isRateLimited = await claudeConsoleAccountService.isAccountRateLimited(boundConsoleAccount.id);
|
||||
const boundConsoleAccount = await claudeConsoleAccountService.getAccount(
|
||||
apiKeyData.claudeConsoleAccountId
|
||||
)
|
||||
if (
|
||||
boundConsoleAccount &&
|
||||
boundConsoleAccount.isActive === true &&
|
||||
boundConsoleAccount.status === 'active'
|
||||
) {
|
||||
const isRateLimited = await claudeConsoleAccountService.isAccountRateLimited(
|
||||
boundConsoleAccount.id
|
||||
)
|
||||
if (!isRateLimited) {
|
||||
logger.info(`🎯 Using bound dedicated Claude Console account: ${boundConsoleAccount.name} (${apiKeyData.claudeConsoleAccountId})`);
|
||||
return [{
|
||||
...boundConsoleAccount,
|
||||
accountId: boundConsoleAccount.id,
|
||||
accountType: 'claude-console',
|
||||
priority: parseInt(boundConsoleAccount.priority) || 50,
|
||||
lastUsedAt: boundConsoleAccount.lastUsedAt || '0'
|
||||
}];
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Claude Console account: ${boundConsoleAccount.name} (${apiKeyData.claudeConsoleAccountId})`
|
||||
)
|
||||
return [
|
||||
{
|
||||
...boundConsoleAccount,
|
||||
accountId: boundConsoleAccount.id,
|
||||
accountType: 'claude-console',
|
||||
priority: parseInt(boundConsoleAccount.priority) || 50,
|
||||
lastUsedAt: boundConsoleAccount.lastUsedAt || '0'
|
||||
}
|
||||
]
|
||||
}
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Claude Console account ${apiKeyData.claudeConsoleAccountId} is not available`);
|
||||
logger.warn(
|
||||
`⚠️ Bound Claude Console account ${apiKeyData.claudeConsoleAccountId} is not available`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 检查Bedrock账户绑定
|
||||
if (apiKeyData.bedrockAccountId) {
|
||||
const boundBedrockAccountResult = await bedrockAccountService.getAccount(apiKeyData.bedrockAccountId);
|
||||
const boundBedrockAccountResult = await bedrockAccountService.getAccount(
|
||||
apiKeyData.bedrockAccountId
|
||||
)
|
||||
if (boundBedrockAccountResult.success && boundBedrockAccountResult.data.isActive === true) {
|
||||
logger.info(`🎯 Using bound dedicated Bedrock account: ${boundBedrockAccountResult.data.name} (${apiKeyData.bedrockAccountId})`);
|
||||
return [{
|
||||
...boundBedrockAccountResult.data,
|
||||
accountId: boundBedrockAccountResult.data.id,
|
||||
accountType: 'bedrock',
|
||||
priority: parseInt(boundBedrockAccountResult.data.priority) || 50,
|
||||
lastUsedAt: boundBedrockAccountResult.data.lastUsedAt || '0'
|
||||
}];
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Bedrock account: ${boundBedrockAccountResult.data.name} (${apiKeyData.bedrockAccountId})`
|
||||
)
|
||||
return [
|
||||
{
|
||||
...boundBedrockAccountResult.data,
|
||||
accountId: boundBedrockAccountResult.data.id,
|
||||
accountType: 'bedrock',
|
||||
priority: parseInt(boundBedrockAccountResult.data.priority) || 50,
|
||||
lastUsedAt: boundBedrockAccountResult.data.lastUsedAt || '0'
|
||||
}
|
||||
]
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Bedrock account ${apiKeyData.bedrockAccountId} is not available`);
|
||||
logger.warn(`⚠️ Bound Bedrock account ${apiKeyData.bedrockAccountId} is not available`)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取官方Claude账户(共享池)
|
||||
const claudeAccounts = await redis.getAllClaudeAccounts();
|
||||
const claudeAccounts = await redis.getAllClaudeAccounts()
|
||||
for (const account of claudeAccounts) {
|
||||
if (account.isActive === 'true' &&
|
||||
account.status !== 'error' &&
|
||||
account.status !== 'blocked' &&
|
||||
(account.accountType === 'shared' || !account.accountType) && // 兼容旧数据
|
||||
this._isSchedulable(account.schedulable)) { // 检查是否可调度
|
||||
|
||||
if (
|
||||
account.isActive === 'true' &&
|
||||
account.status !== 'error' &&
|
||||
account.status !== 'blocked' &&
|
||||
(account.accountType === 'shared' || !account.accountType) && // 兼容旧数据
|
||||
this._isSchedulable(account.schedulable)
|
||||
) {
|
||||
// 检查是否可调度
|
||||
|
||||
// 检查是否被限流
|
||||
const isRateLimited = await claudeAccountService.isAccountRateLimited(account.id);
|
||||
const isRateLimited = await claudeAccountService.isAccountRateLimited(account.id)
|
||||
if (!isRateLimited) {
|
||||
availableAccounts.push({
|
||||
...account,
|
||||
@@ -205,44 +276,59 @@ class UnifiedClaudeScheduler {
|
||||
accountType: 'claude-official',
|
||||
priority: parseInt(account.priority) || 50, // 默认优先级50
|
||||
lastUsedAt: account.lastUsedAt || '0'
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取Claude Console账户
|
||||
const consoleAccounts = await claudeConsoleAccountService.getAllAccounts();
|
||||
logger.info(`📋 Found ${consoleAccounts.length} total Claude Console accounts`);
|
||||
|
||||
const consoleAccounts = await claudeConsoleAccountService.getAllAccounts()
|
||||
logger.info(`📋 Found ${consoleAccounts.length} total Claude Console accounts`)
|
||||
|
||||
for (const account of consoleAccounts) {
|
||||
logger.info(`🔍 Checking Claude Console account: ${account.name} - isActive: ${account.isActive}, status: ${account.status}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`);
|
||||
|
||||
logger.info(
|
||||
`🔍 Checking Claude Console account: ${account.name} - isActive: ${account.isActive}, status: ${account.status}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`
|
||||
)
|
||||
|
||||
// 注意:getAllAccounts返回的isActive是布尔值
|
||||
if (account.isActive === true &&
|
||||
account.status === 'active' &&
|
||||
account.accountType === 'shared' &&
|
||||
this._isSchedulable(account.schedulable)) { // 检查是否可调度
|
||||
|
||||
if (
|
||||
account.isActive === true &&
|
||||
account.status === 'active' &&
|
||||
account.accountType === 'shared' &&
|
||||
this._isSchedulable(account.schedulable)
|
||||
) {
|
||||
// 检查是否可调度
|
||||
|
||||
// 检查模型支持(如果有请求的模型)
|
||||
if (requestedModel && account.supportedModels) {
|
||||
// 兼容旧格式(数组)和新格式(对象)
|
||||
if (Array.isArray(account.supportedModels)) {
|
||||
// 旧格式:数组
|
||||
if (account.supportedModels.length > 0 && !account.supportedModels.includes(requestedModel)) {
|
||||
logger.info(`🚫 Claude Console account ${account.name} does not support model ${requestedModel}`);
|
||||
continue;
|
||||
if (
|
||||
account.supportedModels.length > 0 &&
|
||||
!account.supportedModels.includes(requestedModel)
|
||||
) {
|
||||
logger.info(
|
||||
`🚫 Claude Console account ${account.name} does not support model ${requestedModel}`
|
||||
)
|
||||
continue
|
||||
}
|
||||
} else if (typeof account.supportedModels === 'object') {
|
||||
// 新格式:映射表
|
||||
if (Object.keys(account.supportedModels).length > 0 && !claudeConsoleAccountService.isModelSupported(account.supportedModels, requestedModel)) {
|
||||
logger.info(`🚫 Claude Console account ${account.name} does not support model ${requestedModel}`);
|
||||
continue;
|
||||
if (
|
||||
Object.keys(account.supportedModels).length > 0 &&
|
||||
!claudeConsoleAccountService.isModelSupported(account.supportedModels, requestedModel)
|
||||
) {
|
||||
logger.info(
|
||||
`🚫 Claude Console account ${account.name} does not support model ${requestedModel}`
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 检查是否被限流
|
||||
const isRateLimited = await claudeConsoleAccountService.isAccountRateLimited(account.id);
|
||||
const isRateLimited = await claudeConsoleAccountService.isAccountRateLimited(account.id)
|
||||
if (!isRateLimited) {
|
||||
availableAccounts.push({
|
||||
...account,
|
||||
@@ -250,45 +336,60 @@ class UnifiedClaudeScheduler {
|
||||
accountType: 'claude-console',
|
||||
priority: parseInt(account.priority) || 50,
|
||||
lastUsedAt: account.lastUsedAt || '0'
|
||||
});
|
||||
logger.info(`✅ Added Claude Console account to available pool: ${account.name} (priority: ${account.priority})`);
|
||||
})
|
||||
logger.info(
|
||||
`✅ Added Claude Console account to available pool: ${account.name} (priority: ${account.priority})`
|
||||
)
|
||||
} else {
|
||||
logger.warn(`⚠️ Claude Console account ${account.name} is rate limited`);
|
||||
logger.warn(`⚠️ Claude Console account ${account.name} is rate limited`)
|
||||
}
|
||||
} else {
|
||||
logger.info(`❌ Claude Console account ${account.name} not eligible - isActive: ${account.isActive}, status: ${account.status}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`);
|
||||
logger.info(
|
||||
`❌ Claude Console account ${account.name} not eligible - isActive: ${account.isActive}, status: ${account.status}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取Bedrock账户(共享池)
|
||||
const bedrockAccountsResult = await bedrockAccountService.getAllAccounts();
|
||||
const bedrockAccountsResult = await bedrockAccountService.getAllAccounts()
|
||||
if (bedrockAccountsResult.success) {
|
||||
const bedrockAccounts = bedrockAccountsResult.data;
|
||||
logger.info(`📋 Found ${bedrockAccounts.length} total Bedrock accounts`);
|
||||
|
||||
const bedrockAccounts = bedrockAccountsResult.data
|
||||
logger.info(`📋 Found ${bedrockAccounts.length} total Bedrock accounts`)
|
||||
|
||||
for (const account of bedrockAccounts) {
|
||||
logger.info(`🔍 Checking Bedrock account: ${account.name} - isActive: ${account.isActive}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`);
|
||||
|
||||
if (account.isActive === true &&
|
||||
account.accountType === 'shared' &&
|
||||
this._isSchedulable(account.schedulable)) { // 检查是否可调度
|
||||
|
||||
logger.info(
|
||||
`🔍 Checking Bedrock account: ${account.name} - isActive: ${account.isActive}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`
|
||||
)
|
||||
|
||||
if (
|
||||
account.isActive === true &&
|
||||
account.accountType === 'shared' &&
|
||||
this._isSchedulable(account.schedulable)
|
||||
) {
|
||||
// 检查是否可调度
|
||||
|
||||
availableAccounts.push({
|
||||
...account,
|
||||
accountId: account.id,
|
||||
accountType: 'bedrock',
|
||||
priority: parseInt(account.priority) || 50,
|
||||
lastUsedAt: account.lastUsedAt || '0'
|
||||
});
|
||||
logger.info(`✅ Added Bedrock account to available pool: ${account.name} (priority: ${account.priority})`);
|
||||
})
|
||||
logger.info(
|
||||
`✅ Added Bedrock account to available pool: ${account.name} (priority: ${account.priority})`
|
||||
)
|
||||
} else {
|
||||
logger.info(`❌ Bedrock account ${account.name} not eligible - isActive: ${account.isActive}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`);
|
||||
logger.info(
|
||||
`❌ Bedrock account ${account.name} not eligible - isActive: ${account.isActive}, accountType: ${account.accountType}, schedulable: ${account.schedulable}`
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`📊 Total available accounts: ${availableAccounts.length} (Claude: ${availableAccounts.filter(a => a.accountType === 'claude-official').length}, Console: ${availableAccounts.filter(a => a.accountType === 'claude-console').length}, Bedrock: ${availableAccounts.filter(a => a.accountType === 'bedrock').length})`);
|
||||
return availableAccounts;
|
||||
|
||||
logger.info(
|
||||
`📊 Total available accounts: ${availableAccounts.length} (Claude: ${availableAccounts.filter((a) => a.accountType === 'claude-official').length}, Console: ${availableAccounts.filter((a) => a.accountType === 'claude-console').length}, Bedrock: ${availableAccounts.filter((a) => a.accountType === 'bedrock').length})`
|
||||
)
|
||||
return availableAccounts
|
||||
}
|
||||
|
||||
// 🔢 按优先级和最后使用时间排序账户
|
||||
@@ -296,115 +397,123 @@ class UnifiedClaudeScheduler {
|
||||
return accounts.sort((a, b) => {
|
||||
// 首先按优先级排序(数字越小优先级越高)
|
||||
if (a.priority !== b.priority) {
|
||||
return a.priority - b.priority;
|
||||
return a.priority - b.priority
|
||||
}
|
||||
|
||||
|
||||
// 优先级相同时,按最后使用时间排序(最久未使用的优先)
|
||||
const aLastUsed = new Date(a.lastUsedAt || 0).getTime();
|
||||
const bLastUsed = new Date(b.lastUsedAt || 0).getTime();
|
||||
return aLastUsed - bLastUsed;
|
||||
});
|
||||
const aLastUsed = new Date(a.lastUsedAt || 0).getTime()
|
||||
const bLastUsed = new Date(b.lastUsedAt || 0).getTime()
|
||||
return aLastUsed - bLastUsed
|
||||
})
|
||||
}
|
||||
|
||||
// 🔍 检查账户是否可用
|
||||
async _isAccountAvailable(accountId, accountType) {
|
||||
try {
|
||||
if (accountType === 'claude-official') {
|
||||
const account = await redis.getClaudeAccount(accountId);
|
||||
const account = await redis.getClaudeAccount(accountId)
|
||||
if (!account || account.isActive !== 'true' || account.status === 'error') {
|
||||
return false;
|
||||
return false
|
||||
}
|
||||
// 检查是否可调度
|
||||
if (!this._isSchedulable(account.schedulable)) {
|
||||
logger.info(`🚫 Account ${accountId} is not schedulable`);
|
||||
return false;
|
||||
logger.info(`🚫 Account ${accountId} is not schedulable`)
|
||||
return false
|
||||
}
|
||||
return !(await claudeAccountService.isAccountRateLimited(accountId));
|
||||
return !(await claudeAccountService.isAccountRateLimited(accountId))
|
||||
} else if (accountType === 'claude-console') {
|
||||
const account = await claudeConsoleAccountService.getAccount(accountId);
|
||||
const account = await claudeConsoleAccountService.getAccount(accountId)
|
||||
if (!account || !account.isActive || account.status !== 'active') {
|
||||
return false;
|
||||
return false
|
||||
}
|
||||
// 检查是否可调度
|
||||
if (!this._isSchedulable(account.schedulable)) {
|
||||
logger.info(`🚫 Claude Console account ${accountId} is not schedulable`);
|
||||
return false;
|
||||
logger.info(`🚫 Claude Console account ${accountId} is not schedulable`)
|
||||
return false
|
||||
}
|
||||
return !(await claudeConsoleAccountService.isAccountRateLimited(accountId));
|
||||
return !(await claudeConsoleAccountService.isAccountRateLimited(accountId))
|
||||
} else if (accountType === 'bedrock') {
|
||||
const accountResult = await bedrockAccountService.getAccount(accountId);
|
||||
const accountResult = await bedrockAccountService.getAccount(accountId)
|
||||
if (!accountResult.success || !accountResult.data.isActive) {
|
||||
return false;
|
||||
return false
|
||||
}
|
||||
// 检查是否可调度
|
||||
if (!this._isSchedulable(accountResult.data.schedulable)) {
|
||||
logger.info(`🚫 Bedrock account ${accountId} is not schedulable`);
|
||||
return false;
|
||||
logger.info(`🚫 Bedrock account ${accountId} is not schedulable`)
|
||||
return false
|
||||
}
|
||||
// Bedrock账户暂不需要限流检查,因为AWS管理限流
|
||||
return true;
|
||||
return true
|
||||
}
|
||||
return false;
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.warn(`⚠️ Failed to check account availability: ${accountId}`, error);
|
||||
return false;
|
||||
logger.warn(`⚠️ Failed to check account availability: ${accountId}`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 🔗 获取会话映射
|
||||
async _getSessionMapping(sessionHash) {
|
||||
const client = redis.getClientSafe();
|
||||
const mappingData = await client.get(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const mappingData = await client.get(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`)
|
||||
|
||||
if (mappingData) {
|
||||
try {
|
||||
return JSON.parse(mappingData);
|
||||
return JSON.parse(mappingData)
|
||||
} catch (error) {
|
||||
logger.warn('⚠️ Failed to parse session mapping:', error);
|
||||
return null;
|
||||
logger.warn('⚠️ Failed to parse session mapping:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
// 💾 设置会话映射
|
||||
async _setSessionMapping(sessionHash, accountId, accountType) {
|
||||
const client = redis.getClientSafe();
|
||||
const mappingData = JSON.stringify({ accountId, accountType });
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const mappingData = JSON.stringify({ accountId, accountType })
|
||||
|
||||
// 设置1小时过期
|
||||
await client.setex(
|
||||
`${this.SESSION_MAPPING_PREFIX}${sessionHash}`,
|
||||
3600,
|
||||
mappingData
|
||||
);
|
||||
await client.setex(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`, 3600, mappingData)
|
||||
}
|
||||
|
||||
// 🗑️ 删除会话映射
|
||||
async _deleteSessionMapping(sessionHash) {
|
||||
const client = redis.getClientSafe();
|
||||
await client.del(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`);
|
||||
const client = redis.getClientSafe()
|
||||
await client.del(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`)
|
||||
}
|
||||
|
||||
// 🚫 标记账户为限流状态
|
||||
async markAccountRateLimited(accountId, accountType, sessionHash = null, rateLimitResetTimestamp = null) {
|
||||
async markAccountRateLimited(
|
||||
accountId,
|
||||
accountType,
|
||||
sessionHash = null,
|
||||
rateLimitResetTimestamp = null
|
||||
) {
|
||||
try {
|
||||
if (accountType === 'claude-official') {
|
||||
await claudeAccountService.markAccountRateLimited(accountId, sessionHash, rateLimitResetTimestamp);
|
||||
await claudeAccountService.markAccountRateLimited(
|
||||
accountId,
|
||||
sessionHash,
|
||||
rateLimitResetTimestamp
|
||||
)
|
||||
} else if (accountType === 'claude-console') {
|
||||
await claudeConsoleAccountService.markAccountRateLimited(accountId);
|
||||
await claudeConsoleAccountService.markAccountRateLimited(accountId)
|
||||
}
|
||||
|
||||
// 删除会话映射
|
||||
if (sessionHash) {
|
||||
await this._deleteSessionMapping(sessionHash);
|
||||
await this._deleteSessionMapping(sessionHash)
|
||||
}
|
||||
|
||||
return { success: true };
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to mark account as rate limited: ${accountId} (${accountType})`, error);
|
||||
throw error;
|
||||
logger.error(
|
||||
`❌ Failed to mark account as rate limited: ${accountId} (${accountType})`,
|
||||
error
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -412,15 +521,18 @@ class UnifiedClaudeScheduler {
|
||||
async removeAccountRateLimit(accountId, accountType) {
|
||||
try {
|
||||
if (accountType === 'claude-official') {
|
||||
await claudeAccountService.removeAccountRateLimit(accountId);
|
||||
await claudeAccountService.removeAccountRateLimit(accountId)
|
||||
} else if (accountType === 'claude-console') {
|
||||
await claudeConsoleAccountService.removeAccountRateLimit(accountId);
|
||||
await claudeConsoleAccountService.removeAccountRateLimit(accountId)
|
||||
}
|
||||
|
||||
return { success: true };
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to remove rate limit for account: ${accountId} (${accountType})`, error);
|
||||
throw error;
|
||||
logger.error(
|
||||
`❌ Failed to remove rate limit for account: ${accountId} (${accountType})`,
|
||||
error
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -428,25 +540,25 @@ class UnifiedClaudeScheduler {
|
||||
async isAccountRateLimited(accountId, accountType) {
|
||||
try {
|
||||
if (accountType === 'claude-official') {
|
||||
return await claudeAccountService.isAccountRateLimited(accountId);
|
||||
return await claudeAccountService.isAccountRateLimited(accountId)
|
||||
} else if (accountType === 'claude-console') {
|
||||
return await claudeConsoleAccountService.isAccountRateLimited(accountId);
|
||||
return await claudeConsoleAccountService.isAccountRateLimited(accountId)
|
||||
}
|
||||
return false;
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to check rate limit status: ${accountId} (${accountType})`, error);
|
||||
return false;
|
||||
logger.error(`❌ Failed to check rate limit status: ${accountId} (${accountType})`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 🚫 标记Claude Console账户为封锁状态(模型不支持)
|
||||
async blockConsoleAccount(accountId, reason) {
|
||||
try {
|
||||
await claudeConsoleAccountService.blockAccount(accountId, reason);
|
||||
return { success: true };
|
||||
await claudeConsoleAccountService.blockAccount(accountId, reason)
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to block console account: ${accountId}`, error);
|
||||
throw error;
|
||||
logger.error(`❌ Failed to block console account: ${accountId}`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -454,127 +566,149 @@ class UnifiedClaudeScheduler {
|
||||
async selectAccountFromGroup(groupId, sessionHash = null, requestedModel = null) {
|
||||
try {
|
||||
// 获取分组信息
|
||||
const group = await accountGroupService.getGroup(groupId);
|
||||
const group = await accountGroupService.getGroup(groupId)
|
||||
if (!group) {
|
||||
throw new Error(`Group ${groupId} not found`);
|
||||
throw new Error(`Group ${groupId} not found`)
|
||||
}
|
||||
|
||||
logger.info(`👥 Selecting account from group: ${group.name} (${group.platform})`);
|
||||
logger.info(`👥 Selecting account from group: ${group.name} (${group.platform})`)
|
||||
|
||||
// 如果有会话哈希,检查是否有已映射的账户
|
||||
if (sessionHash) {
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash);
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash)
|
||||
if (mappedAccount) {
|
||||
// 验证映射的账户是否属于这个分组
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId);
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId)
|
||||
if (memberIds.includes(mappedAccount.accountId)) {
|
||||
const isAvailable = await this._isAccountAvailable(mappedAccount.accountId, mappedAccount.accountType);
|
||||
const isAvailable = await this._isAccountAvailable(
|
||||
mappedAccount.accountId,
|
||||
mappedAccount.accountType
|
||||
)
|
||||
if (isAvailable) {
|
||||
logger.info(`🎯 Using sticky session account from group: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`);
|
||||
return mappedAccount;
|
||||
logger.info(
|
||||
`🎯 Using sticky session account from group: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
return mappedAccount
|
||||
}
|
||||
}
|
||||
// 如果映射的账户不可用或不在分组中,删除映射
|
||||
await this._deleteSessionMapping(sessionHash);
|
||||
await this._deleteSessionMapping(sessionHash)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取分组内的所有账户
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId);
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId)
|
||||
if (memberIds.length === 0) {
|
||||
throw new Error(`Group ${group.name} has no members`);
|
||||
throw new Error(`Group ${group.name} has no members`)
|
||||
}
|
||||
|
||||
const availableAccounts = [];
|
||||
const availableAccounts = []
|
||||
|
||||
// 获取所有成员账户的详细信息
|
||||
for (const memberId of memberIds) {
|
||||
let account = null;
|
||||
let accountType = null;
|
||||
let account = null
|
||||
let accountType = null
|
||||
|
||||
// 根据平台类型获取账户
|
||||
if (group.platform === 'claude') {
|
||||
// 先尝试官方账户
|
||||
account = await redis.getClaudeAccount(memberId);
|
||||
account = await redis.getClaudeAccount(memberId)
|
||||
if (account?.id) {
|
||||
accountType = 'claude-official';
|
||||
accountType = 'claude-official'
|
||||
} else {
|
||||
// 尝试Console账户
|
||||
account = await claudeConsoleAccountService.getAccount(memberId);
|
||||
account = await claudeConsoleAccountService.getAccount(memberId)
|
||||
if (account) {
|
||||
accountType = 'claude-console';
|
||||
accountType = 'claude-console'
|
||||
}
|
||||
}
|
||||
} else if (group.platform === 'gemini') {
|
||||
// Gemini暂时不支持,预留接口
|
||||
logger.warn('⚠️ Gemini group scheduling not yet implemented');
|
||||
continue;
|
||||
logger.warn('⚠️ Gemini group scheduling not yet implemented')
|
||||
continue
|
||||
}
|
||||
|
||||
if (!account) {
|
||||
logger.warn(`⚠️ Account ${memberId} not found in group ${group.name}`);
|
||||
continue;
|
||||
logger.warn(`⚠️ Account ${memberId} not found in group ${group.name}`)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查账户是否可用
|
||||
const isActive = accountType === 'claude-official'
|
||||
? account.isActive === 'true'
|
||||
: account.isActive === true;
|
||||
|
||||
const status = accountType === 'claude-official'
|
||||
? account.status !== 'error' && account.status !== 'blocked'
|
||||
: account.status === 'active';
|
||||
const isActive =
|
||||
accountType === 'claude-official'
|
||||
? account.isActive === 'true'
|
||||
: account.isActive === true
|
||||
|
||||
const status =
|
||||
accountType === 'claude-official'
|
||||
? account.status !== 'error' && account.status !== 'blocked'
|
||||
: account.status === 'active'
|
||||
|
||||
if (isActive && status && this._isSchedulable(account.schedulable)) {
|
||||
// 检查模型支持(Console账户)
|
||||
if (accountType === 'claude-console' && requestedModel && account.supportedModels && account.supportedModels.length > 0) {
|
||||
if (
|
||||
accountType === 'claude-console' &&
|
||||
requestedModel &&
|
||||
account.supportedModels &&
|
||||
account.supportedModels.length > 0
|
||||
) {
|
||||
if (!account.supportedModels.includes(requestedModel)) {
|
||||
logger.info(`🚫 Account ${account.name} in group does not support model ${requestedModel}`);
|
||||
continue;
|
||||
logger.info(
|
||||
`🚫 Account ${account.name} in group does not support model ${requestedModel}`
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否被限流
|
||||
const isRateLimited = await this.isAccountRateLimited(account.id, accountType);
|
||||
const isRateLimited = await this.isAccountRateLimited(account.id, accountType)
|
||||
if (!isRateLimited) {
|
||||
availableAccounts.push({
|
||||
...account,
|
||||
accountId: account.id,
|
||||
accountType: accountType,
|
||||
accountType,
|
||||
priority: parseInt(account.priority) || 50,
|
||||
lastUsedAt: account.lastUsedAt || '0'
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (availableAccounts.length === 0) {
|
||||
throw new Error(`No available accounts in group ${group.name}`);
|
||||
throw new Error(`No available accounts in group ${group.name}`)
|
||||
}
|
||||
|
||||
// 使用现有的优先级排序逻辑
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts);
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts)
|
||||
|
||||
// 选择第一个账户
|
||||
const selectedAccount = sortedAccounts[0];
|
||||
const selectedAccount = sortedAccounts[0]
|
||||
|
||||
// 如果有会话哈希,建立新的映射
|
||||
if (sessionHash) {
|
||||
await this._setSessionMapping(sessionHash, selectedAccount.accountId, selectedAccount.accountType);
|
||||
logger.info(`🎯 Created new sticky session mapping in group: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`);
|
||||
await this._setSessionMapping(
|
||||
sessionHash,
|
||||
selectedAccount.accountId,
|
||||
selectedAccount.accountType
|
||||
)
|
||||
logger.info(
|
||||
`🎯 Created new sticky session mapping in group: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`🎯 Selected account from group ${group.name}: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority}`);
|
||||
|
||||
logger.info(
|
||||
`🎯 Selected account from group ${group.name}: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority}`
|
||||
)
|
||||
|
||||
return {
|
||||
accountId: selectedAccount.accountId,
|
||||
accountType: selectedAccount.accountType
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to select account from group ${groupId}:`, error);
|
||||
throw error;
|
||||
logger.error(`❌ Failed to select account from group ${groupId}:`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new UnifiedClaudeScheduler();
|
||||
module.exports = new UnifiedClaudeScheduler()
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
const geminiAccountService = require('./geminiAccountService');
|
||||
const accountGroupService = require('./accountGroupService');
|
||||
const redis = require('../models/redis');
|
||||
const logger = require('../utils/logger');
|
||||
const geminiAccountService = require('./geminiAccountService')
|
||||
const accountGroupService = require('./accountGroupService')
|
||||
const redis = require('../models/redis')
|
||||
const logger = require('../utils/logger')
|
||||
|
||||
class UnifiedGeminiScheduler {
|
||||
constructor() {
|
||||
this.SESSION_MAPPING_PREFIX = 'unified_gemini_session_mapping:';
|
||||
this.SESSION_MAPPING_PREFIX = 'unified_gemini_session_mapping:'
|
||||
}
|
||||
|
||||
// 🔧 辅助方法:检查账户是否可调度(兼容字符串和布尔值)
|
||||
_isSchedulable(schedulable) {
|
||||
// 如果是 undefined 或 null,默认为可调度
|
||||
if (schedulable === undefined || schedulable === null) {
|
||||
return true;
|
||||
return true
|
||||
}
|
||||
// 明确设置为 false(布尔值)或 'false'(字符串)时不可调度
|
||||
return schedulable !== false && schedulable !== 'false';
|
||||
return schedulable !== false && schedulable !== 'false'
|
||||
}
|
||||
|
||||
// 🎯 统一调度Gemini账号
|
||||
@@ -25,143 +25,183 @@ class UnifiedGeminiScheduler {
|
||||
if (apiKeyData.geminiAccountId) {
|
||||
// 检查是否是分组
|
||||
if (apiKeyData.geminiAccountId.startsWith('group:')) {
|
||||
const groupId = apiKeyData.geminiAccountId.replace('group:', '');
|
||||
logger.info(`🎯 API key ${apiKeyData.name} is bound to group ${groupId}, selecting from group`);
|
||||
return await this.selectAccountFromGroup(groupId, sessionHash, requestedModel, apiKeyData);
|
||||
const groupId = apiKeyData.geminiAccountId.replace('group:', '')
|
||||
logger.info(
|
||||
`🎯 API key ${apiKeyData.name} is bound to group ${groupId}, selecting from group`
|
||||
)
|
||||
return await this.selectAccountFromGroup(groupId, sessionHash, requestedModel, apiKeyData)
|
||||
}
|
||||
|
||||
|
||||
// 普通专属账户
|
||||
const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId);
|
||||
const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId)
|
||||
if (boundAccount && boundAccount.isActive === 'true' && boundAccount.status !== 'error') {
|
||||
logger.info(`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId}) for API key ${apiKeyData.name}`);
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId}) for API key ${apiKeyData.name}`
|
||||
)
|
||||
return {
|
||||
accountId: apiKeyData.geminiAccountId,
|
||||
accountType: 'gemini'
|
||||
};
|
||||
}
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Gemini account ${apiKeyData.geminiAccountId} is not available, falling back to pool`);
|
||||
logger.warn(
|
||||
`⚠️ Bound Gemini account ${apiKeyData.geminiAccountId} is not available, falling back to pool`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有会话哈希,检查是否有已映射的账户
|
||||
if (sessionHash) {
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash);
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash)
|
||||
if (mappedAccount) {
|
||||
// 验证映射的账户是否仍然可用
|
||||
const isAvailable = await this._isAccountAvailable(mappedAccount.accountId, mappedAccount.accountType);
|
||||
const isAvailable = await this._isAccountAvailable(
|
||||
mappedAccount.accountId,
|
||||
mappedAccount.accountType
|
||||
)
|
||||
if (isAvailable) {
|
||||
logger.info(`🎯 Using sticky session account: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`);
|
||||
return mappedAccount;
|
||||
logger.info(
|
||||
`🎯 Using sticky session account: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
return mappedAccount
|
||||
} else {
|
||||
logger.warn(`⚠️ Mapped account ${mappedAccount.accountId} is no longer available, selecting new account`);
|
||||
await this._deleteSessionMapping(sessionHash);
|
||||
logger.warn(
|
||||
`⚠️ Mapped account ${mappedAccount.accountId} is no longer available, selecting new account`
|
||||
)
|
||||
await this._deleteSessionMapping(sessionHash)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有可用账户
|
||||
const availableAccounts = await this._getAllAvailableAccounts(apiKeyData, requestedModel);
|
||||
|
||||
const availableAccounts = await this._getAllAvailableAccounts(apiKeyData, requestedModel)
|
||||
|
||||
if (availableAccounts.length === 0) {
|
||||
// 提供更详细的错误信息
|
||||
if (requestedModel) {
|
||||
throw new Error(`No available Gemini accounts support the requested model: ${requestedModel}`);
|
||||
throw new Error(
|
||||
`No available Gemini accounts support the requested model: ${requestedModel}`
|
||||
)
|
||||
} else {
|
||||
throw new Error('No available Gemini accounts');
|
||||
throw new Error('No available Gemini accounts')
|
||||
}
|
||||
}
|
||||
|
||||
// 按优先级和最后使用时间排序
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts);
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts)
|
||||
|
||||
// 选择第一个账户
|
||||
const selectedAccount = sortedAccounts[0];
|
||||
|
||||
const selectedAccount = sortedAccounts[0]
|
||||
|
||||
// 如果有会话哈希,建立新的映射
|
||||
if (sessionHash) {
|
||||
await this._setSessionMapping(sessionHash, selectedAccount.accountId, selectedAccount.accountType);
|
||||
logger.info(`🎯 Created new sticky session mapping: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`);
|
||||
await this._setSessionMapping(
|
||||
sessionHash,
|
||||
selectedAccount.accountId,
|
||||
selectedAccount.accountType
|
||||
)
|
||||
logger.info(
|
||||
`🎯 Created new sticky session mapping: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`🎯 Selected account: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority} for API key ${apiKeyData.name}`);
|
||||
|
||||
logger.info(
|
||||
`🎯 Selected account: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority} for API key ${apiKeyData.name}`
|
||||
)
|
||||
|
||||
return {
|
||||
accountId: selectedAccount.accountId,
|
||||
accountType: selectedAccount.accountType
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('❌ Failed to select account for API key:', error);
|
||||
throw error;
|
||||
logger.error('❌ Failed to select account for API key:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 📋 获取所有可用账户
|
||||
async _getAllAvailableAccounts(apiKeyData, requestedModel = null) {
|
||||
const availableAccounts = [];
|
||||
const availableAccounts = []
|
||||
|
||||
// 如果API Key绑定了专属账户,优先返回
|
||||
if (apiKeyData.geminiAccountId) {
|
||||
const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId);
|
||||
const boundAccount = await geminiAccountService.getAccount(apiKeyData.geminiAccountId)
|
||||
if (boundAccount && boundAccount.isActive === 'true' && boundAccount.status !== 'error') {
|
||||
const isRateLimited = await this.isAccountRateLimited(boundAccount.id);
|
||||
const isRateLimited = await this.isAccountRateLimited(boundAccount.id)
|
||||
if (!isRateLimited) {
|
||||
// 检查模型支持
|
||||
if (requestedModel && boundAccount.supportedModels && boundAccount.supportedModels.length > 0) {
|
||||
if (
|
||||
requestedModel &&
|
||||
boundAccount.supportedModels &&
|
||||
boundAccount.supportedModels.length > 0
|
||||
) {
|
||||
// 处理可能带有 models/ 前缀的模型名
|
||||
const normalizedModel = requestedModel.replace('models/', '');
|
||||
const modelSupported = boundAccount.supportedModels.some(model =>
|
||||
model.replace('models/', '') === normalizedModel
|
||||
);
|
||||
const normalizedModel = requestedModel.replace('models/', '')
|
||||
const modelSupported = boundAccount.supportedModels.some(
|
||||
(model) => model.replace('models/', '') === normalizedModel
|
||||
)
|
||||
if (!modelSupported) {
|
||||
logger.warn(`⚠️ Bound Gemini account ${boundAccount.name} does not support model ${requestedModel}`);
|
||||
return availableAccounts;
|
||||
logger.warn(
|
||||
`⚠️ Bound Gemini account ${boundAccount.name} does not support model ${requestedModel}`
|
||||
)
|
||||
return availableAccounts
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId})`);
|
||||
return [{
|
||||
...boundAccount,
|
||||
accountId: boundAccount.id,
|
||||
accountType: 'gemini',
|
||||
priority: parseInt(boundAccount.priority) || 50,
|
||||
lastUsedAt: boundAccount.lastUsedAt || '0'
|
||||
}];
|
||||
|
||||
logger.info(
|
||||
`🎯 Using bound dedicated Gemini account: ${boundAccount.name} (${apiKeyData.geminiAccountId})`
|
||||
)
|
||||
return [
|
||||
{
|
||||
...boundAccount,
|
||||
accountId: boundAccount.id,
|
||||
accountType: 'gemini',
|
||||
priority: parseInt(boundAccount.priority) || 50,
|
||||
lastUsedAt: boundAccount.lastUsedAt || '0'
|
||||
}
|
||||
]
|
||||
}
|
||||
} else {
|
||||
logger.warn(`⚠️ Bound Gemini account ${apiKeyData.geminiAccountId} is not available`);
|
||||
logger.warn(`⚠️ Bound Gemini account ${apiKeyData.geminiAccountId} is not available`)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取所有Gemini账户(共享池)
|
||||
const geminiAccounts = await geminiAccountService.getAllAccounts();
|
||||
const geminiAccounts = await geminiAccountService.getAllAccounts()
|
||||
for (const account of geminiAccounts) {
|
||||
if (account.isActive === 'true' &&
|
||||
account.status !== 'error' &&
|
||||
(account.accountType === 'shared' || !account.accountType) && // 兼容旧数据
|
||||
this._isSchedulable(account.schedulable)) { // 检查是否可调度
|
||||
|
||||
if (
|
||||
account.isActive === 'true' &&
|
||||
account.status !== 'error' &&
|
||||
(account.accountType === 'shared' || !account.accountType) && // 兼容旧数据
|
||||
this._isSchedulable(account.schedulable)
|
||||
) {
|
||||
// 检查是否可调度
|
||||
|
||||
// 检查token是否过期
|
||||
const isExpired = geminiAccountService.isTokenExpired(account);
|
||||
const isExpired = geminiAccountService.isTokenExpired(account)
|
||||
if (isExpired && !account.refreshToken) {
|
||||
logger.warn(`⚠️ Gemini account ${account.name} token expired and no refresh token available`);
|
||||
continue;
|
||||
logger.warn(
|
||||
`⚠️ Gemini account ${account.name} token expired and no refresh token available`
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// 检查模型支持
|
||||
if (requestedModel && account.supportedModels && account.supportedModels.length > 0) {
|
||||
// 处理可能带有 models/ 前缀的模型名
|
||||
const normalizedModel = requestedModel.replace('models/', '');
|
||||
const modelSupported = account.supportedModels.some(model =>
|
||||
model.replace('models/', '') === normalizedModel
|
||||
);
|
||||
const normalizedModel = requestedModel.replace('models/', '')
|
||||
const modelSupported = account.supportedModels.some(
|
||||
(model) => model.replace('models/', '') === normalizedModel
|
||||
)
|
||||
if (!modelSupported) {
|
||||
logger.debug(`⏭️ Skipping Gemini account ${account.name} - doesn't support model ${requestedModel}`);
|
||||
continue;
|
||||
logger.debug(
|
||||
`⏭️ Skipping Gemini account ${account.name} - doesn't support model ${requestedModel}`
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 检查是否被限流
|
||||
const isRateLimited = await this.isAccountRateLimited(account.id);
|
||||
const isRateLimited = await this.isAccountRateLimited(account.id)
|
||||
if (!isRateLimited) {
|
||||
availableAccounts.push({
|
||||
...account,
|
||||
@@ -169,13 +209,13 @@ class UnifiedGeminiScheduler {
|
||||
accountType: 'gemini',
|
||||
priority: parseInt(account.priority) || 50, // 默认优先级50
|
||||
lastUsedAt: account.lastUsedAt || '0'
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`📊 Total available Gemini accounts: ${availableAccounts.length}`);
|
||||
return availableAccounts;
|
||||
|
||||
logger.info(`📊 Total available Gemini accounts: ${availableAccounts.length}`)
|
||||
return availableAccounts
|
||||
}
|
||||
|
||||
// 🔢 按优先级和最后使用时间排序账户
|
||||
@@ -183,90 +223,89 @@ class UnifiedGeminiScheduler {
|
||||
return accounts.sort((a, b) => {
|
||||
// 首先按优先级排序(数字越小优先级越高)
|
||||
if (a.priority !== b.priority) {
|
||||
return a.priority - b.priority;
|
||||
return a.priority - b.priority
|
||||
}
|
||||
|
||||
|
||||
// 优先级相同时,按最后使用时间排序(最久未使用的优先)
|
||||
const aLastUsed = new Date(a.lastUsedAt || 0).getTime();
|
||||
const bLastUsed = new Date(b.lastUsedAt || 0).getTime();
|
||||
return aLastUsed - bLastUsed;
|
||||
});
|
||||
const aLastUsed = new Date(a.lastUsedAt || 0).getTime()
|
||||
const bLastUsed = new Date(b.lastUsedAt || 0).getTime()
|
||||
return aLastUsed - bLastUsed
|
||||
})
|
||||
}
|
||||
|
||||
// 🔍 检查账户是否可用
|
||||
async _isAccountAvailable(accountId, accountType) {
|
||||
try {
|
||||
if (accountType === 'gemini') {
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
if (!account || account.isActive !== 'true' || account.status === 'error') {
|
||||
return false;
|
||||
return false
|
||||
}
|
||||
// 检查是否可调度
|
||||
if (!this._isSchedulable(account.schedulable)) {
|
||||
logger.info(`🚫 Gemini account ${accountId} is not schedulable`);
|
||||
return false;
|
||||
logger.info(`🚫 Gemini account ${accountId} is not schedulable`)
|
||||
return false
|
||||
}
|
||||
return !(await this.isAccountRateLimited(accountId));
|
||||
return !(await this.isAccountRateLimited(accountId))
|
||||
}
|
||||
return false;
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.warn(`⚠️ Failed to check account availability: ${accountId}`, error);
|
||||
return false;
|
||||
logger.warn(`⚠️ Failed to check account availability: ${accountId}`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 🔗 获取会话映射
|
||||
async _getSessionMapping(sessionHash) {
|
||||
const client = redis.getClientSafe();
|
||||
const mappingData = await client.get(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`);
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const mappingData = await client.get(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`)
|
||||
|
||||
if (mappingData) {
|
||||
try {
|
||||
return JSON.parse(mappingData);
|
||||
return JSON.parse(mappingData)
|
||||
} catch (error) {
|
||||
logger.warn('⚠️ Failed to parse session mapping:', error);
|
||||
return null;
|
||||
logger.warn('⚠️ Failed to parse session mapping:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
// 💾 设置会话映射
|
||||
async _setSessionMapping(sessionHash, accountId, accountType) {
|
||||
const client = redis.getClientSafe();
|
||||
const mappingData = JSON.stringify({ accountId, accountType });
|
||||
|
||||
const client = redis.getClientSafe()
|
||||
const mappingData = JSON.stringify({ accountId, accountType })
|
||||
|
||||
// 设置1小时过期
|
||||
await client.setex(
|
||||
`${this.SESSION_MAPPING_PREFIX}${sessionHash}`,
|
||||
3600,
|
||||
mappingData
|
||||
);
|
||||
await client.setex(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`, 3600, mappingData)
|
||||
}
|
||||
|
||||
// 🗑️ 删除会话映射
|
||||
async _deleteSessionMapping(sessionHash) {
|
||||
const client = redis.getClientSafe();
|
||||
await client.del(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`);
|
||||
const client = redis.getClientSafe()
|
||||
await client.del(`${this.SESSION_MAPPING_PREFIX}${sessionHash}`)
|
||||
}
|
||||
|
||||
// 🚫 标记账户为限流状态
|
||||
async markAccountRateLimited(accountId, accountType, sessionHash = null) {
|
||||
try {
|
||||
if (accountType === 'gemini') {
|
||||
await geminiAccountService.setAccountRateLimited(accountId, true);
|
||||
await geminiAccountService.setAccountRateLimited(accountId, true)
|
||||
}
|
||||
|
||||
// 删除会话映射
|
||||
if (sessionHash) {
|
||||
await this._deleteSessionMapping(sessionHash);
|
||||
await this._deleteSessionMapping(sessionHash)
|
||||
}
|
||||
|
||||
return { success: true };
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to mark account as rate limited: ${accountId} (${accountType})`, error);
|
||||
throw error;
|
||||
logger.error(
|
||||
`❌ Failed to mark account as rate limited: ${accountId} (${accountType})`,
|
||||
error
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,33 +313,38 @@ class UnifiedGeminiScheduler {
|
||||
async removeAccountRateLimit(accountId, accountType) {
|
||||
try {
|
||||
if (accountType === 'gemini') {
|
||||
await geminiAccountService.setAccountRateLimited(accountId, false);
|
||||
await geminiAccountService.setAccountRateLimited(accountId, false)
|
||||
}
|
||||
|
||||
return { success: true };
|
||||
return { success: true }
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to remove rate limit for account: ${accountId} (${accountType})`, error);
|
||||
throw error;
|
||||
logger.error(
|
||||
`❌ Failed to remove rate limit for account: ${accountId} (${accountType})`,
|
||||
error
|
||||
)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 🔍 检查账户是否处于限流状态
|
||||
async isAccountRateLimited(accountId) {
|
||||
try {
|
||||
const account = await geminiAccountService.getAccount(accountId);
|
||||
if (!account) return false;
|
||||
|
||||
if (account.rateLimitStatus === 'limited' && account.rateLimitedAt) {
|
||||
const limitedAt = new Date(account.rateLimitedAt).getTime();
|
||||
const now = Date.now();
|
||||
const limitDuration = 60 * 60 * 1000; // 1小时
|
||||
|
||||
return now < (limitedAt + limitDuration);
|
||||
const account = await geminiAccountService.getAccount(accountId)
|
||||
if (!account) {
|
||||
return false
|
||||
}
|
||||
return false;
|
||||
|
||||
if (account.rateLimitStatus === 'limited' && account.rateLimitedAt) {
|
||||
const limitedAt = new Date(account.rateLimitedAt).getTime()
|
||||
const now = Date.now()
|
||||
const limitDuration = 60 * 60 * 1000 // 1小时
|
||||
|
||||
return now < limitedAt + limitDuration
|
||||
}
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to check rate limit status: ${accountId}`, error);
|
||||
return false;
|
||||
logger.error(`❌ Failed to check rate limit status: ${accountId}`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,79 +352,89 @@ class UnifiedGeminiScheduler {
|
||||
async selectAccountFromGroup(groupId, sessionHash = null, requestedModel = null) {
|
||||
try {
|
||||
// 获取分组信息
|
||||
const group = await accountGroupService.getGroup(groupId);
|
||||
const group = await accountGroupService.getGroup(groupId)
|
||||
if (!group) {
|
||||
throw new Error(`Group ${groupId} not found`);
|
||||
}
|
||||
|
||||
if (group.platform !== 'gemini') {
|
||||
throw new Error(`Group ${group.name} is not a Gemini group`);
|
||||
throw new Error(`Group ${groupId} not found`)
|
||||
}
|
||||
|
||||
logger.info(`👥 Selecting account from Gemini group: ${group.name}`);
|
||||
if (group.platform !== 'gemini') {
|
||||
throw new Error(`Group ${group.name} is not a Gemini group`)
|
||||
}
|
||||
|
||||
logger.info(`👥 Selecting account from Gemini group: ${group.name}`)
|
||||
|
||||
// 如果有会话哈希,检查是否有已映射的账户
|
||||
if (sessionHash) {
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash);
|
||||
const mappedAccount = await this._getSessionMapping(sessionHash)
|
||||
if (mappedAccount) {
|
||||
// 验证映射的账户是否属于这个分组
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId);
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId)
|
||||
if (memberIds.includes(mappedAccount.accountId)) {
|
||||
const isAvailable = await this._isAccountAvailable(mappedAccount.accountId, mappedAccount.accountType);
|
||||
const isAvailable = await this._isAccountAvailable(
|
||||
mappedAccount.accountId,
|
||||
mappedAccount.accountType
|
||||
)
|
||||
if (isAvailable) {
|
||||
logger.info(`🎯 Using sticky session account from group: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`);
|
||||
return mappedAccount;
|
||||
logger.info(
|
||||
`🎯 Using sticky session account from group: ${mappedAccount.accountId} (${mappedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
return mappedAccount
|
||||
}
|
||||
}
|
||||
// 如果映射的账户不可用或不在分组中,删除映射
|
||||
await this._deleteSessionMapping(sessionHash);
|
||||
await this._deleteSessionMapping(sessionHash)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取分组内的所有账户
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId);
|
||||
const memberIds = await accountGroupService.getGroupMembers(groupId)
|
||||
if (memberIds.length === 0) {
|
||||
throw new Error(`Group ${group.name} has no members`);
|
||||
throw new Error(`Group ${group.name} has no members`)
|
||||
}
|
||||
|
||||
const availableAccounts = [];
|
||||
const availableAccounts = []
|
||||
|
||||
// 获取所有成员账户的详细信息
|
||||
for (const memberId of memberIds) {
|
||||
const account = await geminiAccountService.getAccount(memberId);
|
||||
|
||||
const account = await geminiAccountService.getAccount(memberId)
|
||||
|
||||
if (!account) {
|
||||
logger.warn(`⚠️ Gemini account ${memberId} not found in group ${group.name}`);
|
||||
continue;
|
||||
logger.warn(`⚠️ Gemini account ${memberId} not found in group ${group.name}`)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查账户是否可用
|
||||
if (account.isActive === 'true' &&
|
||||
account.status !== 'error' &&
|
||||
this._isSchedulable(account.schedulable)) {
|
||||
|
||||
if (
|
||||
account.isActive === 'true' &&
|
||||
account.status !== 'error' &&
|
||||
this._isSchedulable(account.schedulable)
|
||||
) {
|
||||
// 检查token是否过期
|
||||
const isExpired = geminiAccountService.isTokenExpired(account);
|
||||
const isExpired = geminiAccountService.isTokenExpired(account)
|
||||
if (isExpired && !account.refreshToken) {
|
||||
logger.warn(`⚠️ Gemini account ${account.name} in group token expired and no refresh token available`);
|
||||
continue;
|
||||
logger.warn(
|
||||
`⚠️ Gemini account ${account.name} in group token expired and no refresh token available`
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查模型支持
|
||||
if (requestedModel && account.supportedModels && account.supportedModels.length > 0) {
|
||||
// 处理可能带有 models/ 前缀的模型名
|
||||
const normalizedModel = requestedModel.replace('models/', '');
|
||||
const modelSupported = account.supportedModels.some(model =>
|
||||
model.replace('models/', '') === normalizedModel
|
||||
);
|
||||
const normalizedModel = requestedModel.replace('models/', '')
|
||||
const modelSupported = account.supportedModels.some(
|
||||
(model) => model.replace('models/', '') === normalizedModel
|
||||
)
|
||||
if (!modelSupported) {
|
||||
logger.debug(`⏭️ Skipping Gemini account ${account.name} in group - doesn't support model ${requestedModel}`);
|
||||
continue;
|
||||
logger.debug(
|
||||
`⏭️ Skipping Gemini account ${account.name} in group - doesn't support model ${requestedModel}`
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 检查是否被限流
|
||||
const isRateLimited = await this.isAccountRateLimited(account.id);
|
||||
const isRateLimited = await this.isAccountRateLimited(account.id)
|
||||
if (!isRateLimited) {
|
||||
availableAccounts.push({
|
||||
...account,
|
||||
@@ -388,38 +442,46 @@ class UnifiedGeminiScheduler {
|
||||
accountType: 'gemini',
|
||||
priority: parseInt(account.priority) || 50,
|
||||
lastUsedAt: account.lastUsedAt || '0'
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (availableAccounts.length === 0) {
|
||||
throw new Error(`No available accounts in Gemini group ${group.name}`);
|
||||
throw new Error(`No available accounts in Gemini group ${group.name}`)
|
||||
}
|
||||
|
||||
// 使用现有的优先级排序逻辑
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts);
|
||||
const sortedAccounts = this._sortAccountsByPriority(availableAccounts)
|
||||
|
||||
// 选择第一个账户
|
||||
const selectedAccount = sortedAccounts[0];
|
||||
const selectedAccount = sortedAccounts[0]
|
||||
|
||||
// 如果有会话哈希,建立新的映射
|
||||
if (sessionHash) {
|
||||
await this._setSessionMapping(sessionHash, selectedAccount.accountId, selectedAccount.accountType);
|
||||
logger.info(`🎯 Created new sticky session mapping in group: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`);
|
||||
await this._setSessionMapping(
|
||||
sessionHash,
|
||||
selectedAccount.accountId,
|
||||
selectedAccount.accountType
|
||||
)
|
||||
logger.info(
|
||||
`🎯 Created new sticky session mapping in group: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) for session ${sessionHash}`
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`🎯 Selected account from Gemini group ${group.name}: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority}`);
|
||||
|
||||
logger.info(
|
||||
`🎯 Selected account from Gemini group ${group.name}: ${selectedAccount.name} (${selectedAccount.accountId}, ${selectedAccount.accountType}) with priority ${selectedAccount.priority}`
|
||||
)
|
||||
|
||||
return {
|
||||
accountId: selectedAccount.accountId,
|
||||
accountType: selectedAccount.accountType
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`❌ Failed to select account from Gemini group ${groupId}:`, error);
|
||||
throw error;
|
||||
logger.error(`❌ Failed to select account from Gemini group ${groupId}:`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new UnifiedGeminiScheduler();
|
||||
module.exports = new UnifiedGeminiScheduler()
|
||||
|
||||
@@ -1,64 +1,63 @@
|
||||
const pricingService = require('../services/pricingService');
|
||||
const pricingService = require('../services/pricingService')
|
||||
|
||||
// Claude模型价格配置 (USD per 1M tokens) - 备用定价
|
||||
const MODEL_PRICING = {
|
||||
// Claude 3.5 Sonnet
|
||||
'claude-3-5-sonnet-20241022': {
|
||||
input: 3.00,
|
||||
output: 15.00,
|
||||
input: 3.0,
|
||||
output: 15.0,
|
||||
cacheWrite: 3.75,
|
||||
cacheRead: 0.30
|
||||
cacheRead: 0.3
|
||||
},
|
||||
'claude-sonnet-4-20250514': {
|
||||
input: 3.00,
|
||||
output: 15.00,
|
||||
input: 3.0,
|
||||
output: 15.0,
|
||||
cacheWrite: 3.75,
|
||||
cacheRead: 0.30
|
||||
cacheRead: 0.3
|
||||
},
|
||||
|
||||
|
||||
// Claude 3.5 Haiku
|
||||
'claude-3-5-haiku-20241022': {
|
||||
input: 0.25,
|
||||
output: 1.25,
|
||||
cacheWrite: 0.30,
|
||||
cacheWrite: 0.3,
|
||||
cacheRead: 0.03
|
||||
},
|
||||
|
||||
|
||||
// Claude 3 Opus
|
||||
'claude-3-opus-20240229': {
|
||||
input: 15.00,
|
||||
output: 75.00,
|
||||
input: 15.0,
|
||||
output: 75.0,
|
||||
cacheWrite: 18.75,
|
||||
cacheRead: 1.50
|
||||
cacheRead: 1.5
|
||||
},
|
||||
|
||||
|
||||
// Claude 3 Sonnet
|
||||
'claude-3-sonnet-20240229': {
|
||||
input: 3.00,
|
||||
output: 15.00,
|
||||
input: 3.0,
|
||||
output: 15.0,
|
||||
cacheWrite: 3.75,
|
||||
cacheRead: 0.30
|
||||
cacheRead: 0.3
|
||||
},
|
||||
|
||||
|
||||
// Claude 3 Haiku
|
||||
'claude-3-haiku-20240307': {
|
||||
input: 0.25,
|
||||
output: 1.25,
|
||||
cacheWrite: 0.30,
|
||||
cacheWrite: 0.3,
|
||||
cacheRead: 0.03
|
||||
},
|
||||
|
||||
|
||||
// 默认定价(用于未知模型)
|
||||
'unknown': {
|
||||
input: 3.00,
|
||||
output: 15.00,
|
||||
unknown: {
|
||||
input: 3.0,
|
||||
output: 15.0,
|
||||
cacheWrite: 3.75,
|
||||
cacheRead: 0.30
|
||||
cacheRead: 0.3
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
class CostCalculator {
|
||||
|
||||
/**
|
||||
* 计算单次请求的费用
|
||||
* @param {Object} usage - 使用量数据
|
||||
@@ -70,16 +69,16 @@ class CostCalculator {
|
||||
* @returns {Object} 费用详情
|
||||
*/
|
||||
static calculateCost(usage, model = 'unknown') {
|
||||
const inputTokens = usage.input_tokens || 0;
|
||||
const outputTokens = usage.output_tokens || 0;
|
||||
const cacheCreateTokens = usage.cache_creation_input_tokens || 0;
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0;
|
||||
|
||||
const inputTokens = usage.input_tokens || 0
|
||||
const outputTokens = usage.output_tokens || 0
|
||||
const cacheCreateTokens = usage.cache_creation_input_tokens || 0
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0
|
||||
|
||||
// 优先使用动态价格服务
|
||||
const pricingData = pricingService.getModelPricing(model);
|
||||
let pricing;
|
||||
let usingDynamicPricing = false;
|
||||
|
||||
const pricingData = pricingService.getModelPricing(model)
|
||||
let pricing
|
||||
let usingDynamicPricing = false
|
||||
|
||||
if (pricingData) {
|
||||
// 转换动态价格格式为内部格式
|
||||
pricing = {
|
||||
@@ -87,21 +86,21 @@ class CostCalculator {
|
||||
output: (pricingData.output_cost_per_token || 0) * 1000000,
|
||||
cacheWrite: (pricingData.cache_creation_input_token_cost || 0) * 1000000,
|
||||
cacheRead: (pricingData.cache_read_input_token_cost || 0) * 1000000
|
||||
};
|
||||
usingDynamicPricing = true;
|
||||
}
|
||||
usingDynamicPricing = true
|
||||
} else {
|
||||
// 回退到静态价格
|
||||
pricing = MODEL_PRICING[model] || MODEL_PRICING['unknown'];
|
||||
pricing = MODEL_PRICING[model] || MODEL_PRICING['unknown']
|
||||
}
|
||||
|
||||
|
||||
// 计算各类型token的费用 (USD)
|
||||
const inputCost = (inputTokens / 1000000) * pricing.input;
|
||||
const outputCost = (outputTokens / 1000000) * pricing.output;
|
||||
const cacheWriteCost = (cacheCreateTokens / 1000000) * pricing.cacheWrite;
|
||||
const cacheReadCost = (cacheReadTokens / 1000000) * pricing.cacheRead;
|
||||
|
||||
const totalCost = inputCost + outputCost + cacheWriteCost + cacheReadCost;
|
||||
|
||||
const inputCost = (inputTokens / 1000000) * pricing.input
|
||||
const outputCost = (outputTokens / 1000000) * pricing.output
|
||||
const cacheWriteCost = (cacheCreateTokens / 1000000) * pricing.cacheWrite
|
||||
const cacheReadCost = (cacheReadTokens / 1000000) * pricing.cacheRead
|
||||
|
||||
const totalCost = inputCost + outputCost + cacheWriteCost + cacheReadCost
|
||||
|
||||
return {
|
||||
model,
|
||||
pricing,
|
||||
@@ -128,9 +127,9 @@ class CostCalculator {
|
||||
cacheRead: this.formatCost(cacheReadCost),
|
||||
total: this.formatCost(totalCost)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 计算聚合使用量的费用
|
||||
* @param {Object} aggregatedUsage - 聚合使用量数据
|
||||
@@ -141,39 +140,41 @@ class CostCalculator {
|
||||
const usage = {
|
||||
input_tokens: aggregatedUsage.inputTokens || aggregatedUsage.totalInputTokens || 0,
|
||||
output_tokens: aggregatedUsage.outputTokens || aggregatedUsage.totalOutputTokens || 0,
|
||||
cache_creation_input_tokens: aggregatedUsage.cacheCreateTokens || aggregatedUsage.totalCacheCreateTokens || 0,
|
||||
cache_read_input_tokens: aggregatedUsage.cacheReadTokens || aggregatedUsage.totalCacheReadTokens || 0
|
||||
};
|
||||
|
||||
return this.calculateCost(usage, model);
|
||||
cache_creation_input_tokens:
|
||||
aggregatedUsage.cacheCreateTokens || aggregatedUsage.totalCacheCreateTokens || 0,
|
||||
cache_read_input_tokens:
|
||||
aggregatedUsage.cacheReadTokens || aggregatedUsage.totalCacheReadTokens || 0
|
||||
}
|
||||
|
||||
return this.calculateCost(usage, model)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取模型定价信息
|
||||
* @param {string} model - 模型名称
|
||||
* @returns {Object} 定价信息
|
||||
*/
|
||||
static getModelPricing(model = 'unknown') {
|
||||
return MODEL_PRICING[model] || MODEL_PRICING['unknown'];
|
||||
return MODEL_PRICING[model] || MODEL_PRICING['unknown']
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 获取所有支持的模型和定价
|
||||
* @returns {Object} 所有模型定价
|
||||
*/
|
||||
static getAllModelPricing() {
|
||||
return { ...MODEL_PRICING };
|
||||
return { ...MODEL_PRICING }
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 验证模型是否支持
|
||||
* @param {string} model - 模型名称
|
||||
* @returns {boolean} 是否支持
|
||||
*/
|
||||
static isModelSupported(model) {
|
||||
return !!MODEL_PRICING[model];
|
||||
return !!MODEL_PRICING[model]
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 格式化费用显示
|
||||
* @param {number} cost - 费用金额
|
||||
@@ -182,14 +183,14 @@ class CostCalculator {
|
||||
*/
|
||||
static formatCost(cost, decimals = 6) {
|
||||
if (cost >= 1) {
|
||||
return `$${cost.toFixed(2)}`;
|
||||
return `$${cost.toFixed(2)}`
|
||||
} else if (cost >= 0.001) {
|
||||
return `$${cost.toFixed(4)}`;
|
||||
return `$${cost.toFixed(4)}`
|
||||
} else {
|
||||
return `$${cost.toFixed(decimals)}`;
|
||||
return `$${cost.toFixed(decimals)}`
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 计算费用节省(使用缓存的节省)
|
||||
* @param {Object} usage - 使用量数据
|
||||
@@ -197,15 +198,15 @@ class CostCalculator {
|
||||
* @returns {Object} 节省信息
|
||||
*/
|
||||
static calculateCacheSavings(usage, model = 'unknown') {
|
||||
const pricing = this.getModelPricing(model);
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0;
|
||||
|
||||
const pricing = this.getModelPricing(model)
|
||||
const cacheReadTokens = usage.cache_read_input_tokens || 0
|
||||
|
||||
// 如果这些token不使用缓存,需要按正常input价格计费
|
||||
const normalCost = (cacheReadTokens / 1000000) * pricing.input;
|
||||
const cacheCost = (cacheReadTokens / 1000000) * pricing.cacheRead;
|
||||
const savings = normalCost - cacheCost;
|
||||
const savingsPercentage = normalCost > 0 ? (savings / normalCost) * 100 : 0;
|
||||
|
||||
const normalCost = (cacheReadTokens / 1000000) * pricing.input
|
||||
const cacheCost = (cacheReadTokens / 1000000) * pricing.cacheRead
|
||||
const savings = normalCost - cacheCost
|
||||
const savingsPercentage = normalCost > 0 ? (savings / normalCost) * 100 : 0
|
||||
|
||||
return {
|
||||
normalCost,
|
||||
cacheCost,
|
||||
@@ -217,8 +218,8 @@ class CostCalculator {
|
||||
savings: this.formatCost(savings),
|
||||
savingsPercentage: `${savingsPercentage.toFixed(1)}%`
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = CostCalculator;
|
||||
module.exports = CostCalculator
|
||||
|
||||
@@ -1,52 +1,58 @@
|
||||
const winston = require('winston');
|
||||
const DailyRotateFile = require('winston-daily-rotate-file');
|
||||
const config = require('../../config/config');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
const os = require('os');
|
||||
const winston = require('winston')
|
||||
const DailyRotateFile = require('winston-daily-rotate-file')
|
||||
const config = require('../../config/config')
|
||||
const path = require('path')
|
||||
const fs = require('fs')
|
||||
const os = require('os')
|
||||
|
||||
// 安全的 JSON 序列化函数,处理循环引用
|
||||
const safeStringify = (obj, maxDepth = 3) => {
|
||||
const seen = new WeakSet();
|
||||
|
||||
const seen = new WeakSet()
|
||||
|
||||
const replacer = (key, value, depth = 0) => {
|
||||
if (depth > maxDepth) return '[Max Depth Reached]';
|
||||
|
||||
if (depth > maxDepth) {
|
||||
return '[Max Depth Reached]'
|
||||
}
|
||||
|
||||
if (value !== null && typeof value === 'object') {
|
||||
if (seen.has(value)) {
|
||||
return '[Circular Reference]';
|
||||
return '[Circular Reference]'
|
||||
}
|
||||
seen.add(value);
|
||||
|
||||
seen.add(value)
|
||||
|
||||
// 过滤掉常见的循环引用对象
|
||||
if (value.constructor) {
|
||||
const constructorName = value.constructor.name;
|
||||
if (['Socket', 'TLSSocket', 'HTTPParser', 'IncomingMessage', 'ServerResponse'].includes(constructorName)) {
|
||||
return `[${constructorName} Object]`;
|
||||
const constructorName = value.constructor.name
|
||||
if (
|
||||
['Socket', 'TLSSocket', 'HTTPParser', 'IncomingMessage', 'ServerResponse'].includes(
|
||||
constructorName
|
||||
)
|
||||
) {
|
||||
return `[${constructorName} Object]`
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 递归处理对象属性
|
||||
if (Array.isArray(value)) {
|
||||
return value.map((item, index) => replacer(index, item, depth + 1));
|
||||
return value.map((item, index) => replacer(index, item, depth + 1))
|
||||
} else {
|
||||
const result = {};
|
||||
const result = {}
|
||||
for (const [k, v] of Object.entries(value)) {
|
||||
result[k] = replacer(k, v, depth + 1);
|
||||
result[k] = replacer(k, v, depth + 1)
|
||||
}
|
||||
return result;
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return value;
|
||||
};
|
||||
|
||||
try {
|
||||
return JSON.stringify(replacer('', obj));
|
||||
} catch (error) {
|
||||
return JSON.stringify({ error: 'Failed to serialize object', message: error.message });
|
||||
|
||||
return value
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
return JSON.stringify(replacer('', obj))
|
||||
} catch (error) {
|
||||
return JSON.stringify({ error: 'Failed to serialize object', message: error.message })
|
||||
}
|
||||
}
|
||||
|
||||
// 📝 增强的日志格式
|
||||
const createLogFormat = (colorize = false) => {
|
||||
@@ -54,12 +60,12 @@ const createLogFormat = (colorize = false) => {
|
||||
winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }),
|
||||
winston.format.errors({ stack: true }),
|
||||
winston.format.metadata({ fillExcept: ['message', 'level', 'timestamp', 'stack'] })
|
||||
];
|
||||
|
||||
]
|
||||
|
||||
if (colorize) {
|
||||
formats.push(winston.format.colorize());
|
||||
formats.push(winston.format.colorize())
|
||||
}
|
||||
|
||||
|
||||
formats.push(
|
||||
winston.format.printf(({ level, message, timestamp, stack, metadata, ...rest }) => {
|
||||
const emoji = {
|
||||
@@ -68,39 +74,39 @@ const createLogFormat = (colorize = false) => {
|
||||
info: 'ℹ️ ',
|
||||
debug: '🐛',
|
||||
verbose: '📝'
|
||||
};
|
||||
|
||||
let logMessage = `${emoji[level] || '📝'} [${timestamp}] ${level.toUpperCase()}: ${message}`;
|
||||
|
||||
}
|
||||
|
||||
let logMessage = `${emoji[level] || '📝'} [${timestamp}] ${level.toUpperCase()}: ${message}`
|
||||
|
||||
// 添加元数据
|
||||
if (metadata && Object.keys(metadata).length > 0) {
|
||||
logMessage += ` | ${safeStringify(metadata)}`;
|
||||
logMessage += ` | ${safeStringify(metadata)}`
|
||||
}
|
||||
|
||||
// 添加其他属性
|
||||
const additionalData = { ...rest };
|
||||
delete additionalData.level;
|
||||
delete additionalData.message;
|
||||
delete additionalData.timestamp;
|
||||
delete additionalData.stack;
|
||||
|
||||
if (Object.keys(additionalData).length > 0) {
|
||||
logMessage += ` | ${safeStringify(additionalData)}`;
|
||||
}
|
||||
|
||||
return stack ? `${logMessage}\n${stack}` : logMessage;
|
||||
})
|
||||
);
|
||||
|
||||
return winston.format.combine(...formats);
|
||||
};
|
||||
|
||||
const logFormat = createLogFormat(false);
|
||||
const consoleFormat = createLogFormat(true);
|
||||
// 添加其他属性
|
||||
const additionalData = { ...rest }
|
||||
delete additionalData.level
|
||||
delete additionalData.message
|
||||
delete additionalData.timestamp
|
||||
delete additionalData.stack
|
||||
|
||||
if (Object.keys(additionalData).length > 0) {
|
||||
logMessage += ` | ${safeStringify(additionalData)}`
|
||||
}
|
||||
|
||||
return stack ? `${logMessage}\n${stack}` : logMessage
|
||||
})
|
||||
)
|
||||
|
||||
return winston.format.combine(...formats)
|
||||
}
|
||||
|
||||
const logFormat = createLogFormat(false)
|
||||
const consoleFormat = createLogFormat(true)
|
||||
|
||||
// 📁 确保日志目录存在并设置权限
|
||||
if (!fs.existsSync(config.logging.dirname)) {
|
||||
fs.mkdirSync(config.logging.dirname, { recursive: true, mode: 0o755 });
|
||||
fs.mkdirSync(config.logging.dirname, { recursive: true, mode: 0o755 })
|
||||
}
|
||||
|
||||
// 🔄 增强的日志轮转配置
|
||||
@@ -113,40 +119,38 @@ const createRotateTransport = (filename, level = null) => {
|
||||
maxFiles: config.logging.maxFiles,
|
||||
auditFile: path.join(config.logging.dirname, `.${filename.replace('%DATE%', 'audit')}.json`),
|
||||
format: logFormat
|
||||
});
|
||||
|
||||
})
|
||||
|
||||
if (level) {
|
||||
transport.level = level;
|
||||
transport.level = level
|
||||
}
|
||||
|
||||
|
||||
// 监听轮转事件
|
||||
transport.on('rotate', (oldFilename, newFilename) => {
|
||||
console.log(`📦 Log rotated: ${oldFilename} -> ${newFilename}`);
|
||||
});
|
||||
|
||||
transport.on('new', (newFilename) => {
|
||||
console.log(`📄 New log file created: ${newFilename}`);
|
||||
});
|
||||
|
||||
transport.on('archive', (zipFilename) => {
|
||||
console.log(`🗜️ Log archived: ${zipFilename}`);
|
||||
});
|
||||
|
||||
return transport;
|
||||
};
|
||||
console.log(`📦 Log rotated: ${oldFilename} -> ${newFilename}`)
|
||||
})
|
||||
|
||||
const dailyRotateFileTransport = createRotateTransport('claude-relay-%DATE%.log');
|
||||
const errorFileTransport = createRotateTransport('claude-relay-error-%DATE%.log', 'error');
|
||||
transport.on('new', (newFilename) => {
|
||||
console.log(`📄 New log file created: ${newFilename}`)
|
||||
})
|
||||
|
||||
transport.on('archive', (zipFilename) => {
|
||||
console.log(`🗜️ Log archived: ${zipFilename}`)
|
||||
})
|
||||
|
||||
return transport
|
||||
}
|
||||
|
||||
const dailyRotateFileTransport = createRotateTransport('claude-relay-%DATE%.log')
|
||||
const errorFileTransport = createRotateTransport('claude-relay-error-%DATE%.log', 'error')
|
||||
|
||||
// 🔒 创建专门的安全日志记录器
|
||||
const securityLogger = winston.createLogger({
|
||||
level: 'warn',
|
||||
format: logFormat,
|
||||
transports: [
|
||||
createRotateTransport('claude-relay-security-%DATE%.log', 'warn')
|
||||
],
|
||||
transports: [createRotateTransport('claude-relay-security-%DATE%.log', 'warn')],
|
||||
silent: false
|
||||
});
|
||||
})
|
||||
|
||||
// 🌟 增强的 Winston logger
|
||||
const logger = winston.createLogger({
|
||||
@@ -156,7 +160,7 @@ const logger = winston.createLogger({
|
||||
// 📄 文件输出
|
||||
dailyRotateFileTransport,
|
||||
errorFileTransport,
|
||||
|
||||
|
||||
// 🖥️ 控制台输出
|
||||
new winston.transports.Console({
|
||||
format: consoleFormat,
|
||||
@@ -164,10 +168,10 @@ const logger = winston.createLogger({
|
||||
handleRejections: false
|
||||
})
|
||||
],
|
||||
|
||||
|
||||
// 🚨 异常处理
|
||||
exceptionHandlers: [
|
||||
new winston.transports.File({
|
||||
new winston.transports.File({
|
||||
filename: path.join(config.logging.dirname, 'exceptions.log'),
|
||||
format: logFormat,
|
||||
maxsize: 10485760, // 10MB
|
||||
@@ -177,10 +181,10 @@ const logger = winston.createLogger({
|
||||
format: consoleFormat
|
||||
})
|
||||
],
|
||||
|
||||
|
||||
// 🔄 未捕获异常处理
|
||||
rejectionHandlers: [
|
||||
new winston.transports.File({
|
||||
new winston.transports.File({
|
||||
filename: path.join(config.logging.dirname, 'rejections.log'),
|
||||
format: logFormat,
|
||||
maxsize: 10485760, // 10MB
|
||||
@@ -190,24 +194,24 @@ const logger = winston.createLogger({
|
||||
format: consoleFormat
|
||||
})
|
||||
],
|
||||
|
||||
|
||||
// 防止进程退出
|
||||
exitOnError: false
|
||||
});
|
||||
})
|
||||
|
||||
// 🎯 增强的自定义方法
|
||||
logger.success = (message, metadata = {}) => {
|
||||
logger.info(`✅ ${message}`, { type: 'success', ...metadata });
|
||||
};
|
||||
logger.info(`✅ ${message}`, { type: 'success', ...metadata })
|
||||
}
|
||||
|
||||
logger.start = (message, metadata = {}) => {
|
||||
logger.info(`🚀 ${message}`, { type: 'startup', ...metadata });
|
||||
};
|
||||
logger.info(`🚀 ${message}`, { type: 'startup', ...metadata })
|
||||
}
|
||||
|
||||
logger.request = (method, url, status, duration, metadata = {}) => {
|
||||
const emoji = status >= 400 ? '🔴' : status >= 300 ? '🟡' : '🟢';
|
||||
const level = status >= 400 ? 'error' : status >= 300 ? 'warn' : 'info';
|
||||
|
||||
const emoji = status >= 400 ? '🔴' : status >= 300 ? '🟡' : '🟢'
|
||||
const level = status >= 400 ? 'error' : status >= 300 ? 'warn' : 'info'
|
||||
|
||||
logger[level](`${emoji} ${method} ${url} - ${status} (${duration}ms)`, {
|
||||
type: 'request',
|
||||
method,
|
||||
@@ -215,12 +219,12 @@ logger.request = (method, url, status, duration, metadata = {}) => {
|
||||
status,
|
||||
duration,
|
||||
...metadata
|
||||
});
|
||||
};
|
||||
})
|
||||
}
|
||||
|
||||
logger.api = (message, metadata = {}) => {
|
||||
logger.info(`🔗 ${message}`, { type: 'api', ...metadata });
|
||||
};
|
||||
logger.info(`🔗 ${message}`, { type: 'api', ...metadata })
|
||||
}
|
||||
|
||||
logger.security = (message, metadata = {}) => {
|
||||
const securityData = {
|
||||
@@ -229,99 +233,99 @@ logger.security = (message, metadata = {}) => {
|
||||
pid: process.pid,
|
||||
hostname: os.hostname(),
|
||||
...metadata
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// 记录到主日志
|
||||
logger.warn(`🔒 ${message}`, securityData);
|
||||
|
||||
logger.warn(`🔒 ${message}`, securityData)
|
||||
|
||||
// 记录到专门的安全日志文件
|
||||
try {
|
||||
securityLogger.warn(`🔒 ${message}`, securityData);
|
||||
securityLogger.warn(`🔒 ${message}`, securityData)
|
||||
} catch (error) {
|
||||
// 如果安全日志文件不可用,只记录到主日志
|
||||
console.warn('Security logger not available:', error.message);
|
||||
console.warn('Security logger not available:', error.message)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
logger.database = (message, metadata = {}) => {
|
||||
logger.debug(`💾 ${message}`, { type: 'database', ...metadata });
|
||||
};
|
||||
logger.debug(`💾 ${message}`, { type: 'database', ...metadata })
|
||||
}
|
||||
|
||||
logger.performance = (message, metadata = {}) => {
|
||||
logger.info(`⚡ ${message}`, { type: 'performance', ...metadata });
|
||||
};
|
||||
logger.info(`⚡ ${message}`, { type: 'performance', ...metadata })
|
||||
}
|
||||
|
||||
logger.audit = (message, metadata = {}) => {
|
||||
logger.info(`📋 ${message}`, {
|
||||
logger.info(`📋 ${message}`, {
|
||||
type: 'audit',
|
||||
timestamp: new Date().toISOString(),
|
||||
pid: process.pid,
|
||||
...metadata
|
||||
});
|
||||
};
|
||||
...metadata
|
||||
})
|
||||
}
|
||||
|
||||
// 🔧 性能监控方法
|
||||
logger.timer = (label) => {
|
||||
const start = Date.now();
|
||||
const start = Date.now()
|
||||
return {
|
||||
end: (message = '', metadata = {}) => {
|
||||
const duration = Date.now() - start;
|
||||
logger.performance(`${label} ${message}`, { duration, ...metadata });
|
||||
return duration;
|
||||
const duration = Date.now() - start
|
||||
logger.performance(`${label} ${message}`, { duration, ...metadata })
|
||||
return duration
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 📊 日志统计
|
||||
logger.stats = {
|
||||
requests: 0,
|
||||
errors: 0,
|
||||
warnings: 0
|
||||
};
|
||||
}
|
||||
|
||||
// 重写原始方法以统计
|
||||
const originalError = logger.error;
|
||||
const originalWarn = logger.warn;
|
||||
const originalInfo = logger.info;
|
||||
const originalError = logger.error
|
||||
const originalWarn = logger.warn
|
||||
const originalInfo = logger.info
|
||||
|
||||
logger.error = function(message, ...args) {
|
||||
logger.stats.errors++;
|
||||
return originalError.call(this, message, ...args);
|
||||
};
|
||||
logger.error = function (message, ...args) {
|
||||
logger.stats.errors++
|
||||
return originalError.call(this, message, ...args)
|
||||
}
|
||||
|
||||
logger.warn = function(message, ...args) {
|
||||
logger.stats.warnings++;
|
||||
return originalWarn.call(this, message, ...args);
|
||||
};
|
||||
logger.warn = function (message, ...args) {
|
||||
logger.stats.warnings++
|
||||
return originalWarn.call(this, message, ...args)
|
||||
}
|
||||
|
||||
logger.info = function(message, ...args) {
|
||||
logger.info = function (message, ...args) {
|
||||
// 检查是否是请求类型的日志
|
||||
if (args.length > 0 && typeof args[0] === 'object' && args[0].type === 'request') {
|
||||
logger.stats.requests++;
|
||||
logger.stats.requests++
|
||||
}
|
||||
return originalInfo.call(this, message, ...args);
|
||||
};
|
||||
return originalInfo.call(this, message, ...args)
|
||||
}
|
||||
|
||||
// 📈 获取日志统计
|
||||
logger.getStats = () => ({ ...logger.stats });
|
||||
logger.getStats = () => ({ ...logger.stats })
|
||||
|
||||
// 🧹 清理统计
|
||||
logger.resetStats = () => {
|
||||
logger.stats.requests = 0;
|
||||
logger.stats.errors = 0;
|
||||
logger.stats.warnings = 0;
|
||||
};
|
||||
logger.stats.requests = 0
|
||||
logger.stats.errors = 0
|
||||
logger.stats.warnings = 0
|
||||
}
|
||||
|
||||
// 📡 健康检查
|
||||
logger.healthCheck = () => {
|
||||
try {
|
||||
const testMessage = 'Logger health check';
|
||||
logger.debug(testMessage);
|
||||
return { healthy: true, timestamp: new Date().toISOString() };
|
||||
const testMessage = 'Logger health check'
|
||||
logger.debug(testMessage)
|
||||
return { healthy: true, timestamp: new Date().toISOString() }
|
||||
} catch (error) {
|
||||
return { healthy: false, error: error.message, timestamp: new Date().toISOString() };
|
||||
return { healthy: false, error: error.message, timestamp: new Date().toISOString() }
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// 🎬 启动日志记录系统
|
||||
logger.start('Logger initialized', {
|
||||
@@ -330,6 +334,6 @@ logger.start('Logger initialized', {
|
||||
maxSize: config.logging.maxSize,
|
||||
maxFiles: config.logging.maxFiles,
|
||||
envOverride: process.env.LOG_LEVEL ? true : false
|
||||
});
|
||||
})
|
||||
|
||||
module.exports = logger;
|
||||
module.exports = logger
|
||||
|
||||
@@ -3,27 +3,27 @@
|
||||
* 基于claude-code-login.js中的OAuth流程实现
|
||||
*/
|
||||
|
||||
const crypto = require('crypto');
|
||||
const { SocksProxyAgent } = require('socks-proxy-agent');
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent');
|
||||
const axios = require('axios');
|
||||
const logger = require('./logger');
|
||||
const crypto = require('crypto')
|
||||
const { SocksProxyAgent } = require('socks-proxy-agent')
|
||||
const { HttpsProxyAgent } = require('https-proxy-agent')
|
||||
const axios = require('axios')
|
||||
const logger = require('./logger')
|
||||
|
||||
// OAuth 配置常量 - 从claude-code-login.js提取
|
||||
const OAUTH_CONFIG = {
|
||||
AUTHORIZE_URL: 'https://claude.ai/oauth/authorize',
|
||||
TOKEN_URL: 'https://console.anthropic.com/v1/oauth/token',
|
||||
CLIENT_ID: '9d1c250a-e61b-44d9-88ed-5944d1962f5e',
|
||||
REDIRECT_URI: 'https://console.anthropic.com/oauth/code/callback',
|
||||
SCOPES: 'org:create_api_key user:profile user:inference'
|
||||
};
|
||||
AUTHORIZE_URL: 'https://claude.ai/oauth/authorize',
|
||||
TOKEN_URL: 'https://console.anthropic.com/v1/oauth/token',
|
||||
CLIENT_ID: '9d1c250a-e61b-44d9-88ed-5944d1962f5e',
|
||||
REDIRECT_URI: 'https://console.anthropic.com/oauth/code/callback',
|
||||
SCOPES: 'org:create_api_key user:profile user:inference'
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成随机的 state 参数
|
||||
* @returns {string} 随机生成的 state (64字符hex)
|
||||
*/
|
||||
function generateState() {
|
||||
return crypto.randomBytes(32).toString('hex');
|
||||
return crypto.randomBytes(32).toString('hex')
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -31,7 +31,7 @@ function generateState() {
|
||||
* @returns {string} base64url 编码的随机字符串
|
||||
*/
|
||||
function generateCodeVerifier() {
|
||||
return crypto.randomBytes(32).toString('base64url');
|
||||
return crypto.randomBytes(32).toString('base64url')
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -40,9 +40,7 @@ function generateCodeVerifier() {
|
||||
* @returns {string} SHA256 哈希后的 base64url 编码字符串
|
||||
*/
|
||||
function generateCodeChallenge(codeVerifier) {
|
||||
return crypto.createHash('sha256')
|
||||
.update(codeVerifier)
|
||||
.digest('base64url');
|
||||
return crypto.createHash('sha256').update(codeVerifier).digest('base64url')
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -52,18 +50,18 @@ function generateCodeChallenge(codeVerifier) {
|
||||
* @returns {string} 完整的授权 URL
|
||||
*/
|
||||
function generateAuthUrl(codeChallenge, state) {
|
||||
const params = new URLSearchParams({
|
||||
code: 'true',
|
||||
client_id: OAUTH_CONFIG.CLIENT_ID,
|
||||
response_type: 'code',
|
||||
redirect_uri: OAUTH_CONFIG.REDIRECT_URI,
|
||||
scope: OAUTH_CONFIG.SCOPES,
|
||||
code_challenge: codeChallenge,
|
||||
code_challenge_method: 'S256',
|
||||
state: state
|
||||
});
|
||||
const params = new URLSearchParams({
|
||||
code: 'true',
|
||||
client_id: OAUTH_CONFIG.CLIENT_ID,
|
||||
response_type: 'code',
|
||||
redirect_uri: OAUTH_CONFIG.REDIRECT_URI,
|
||||
scope: OAUTH_CONFIG.SCOPES,
|
||||
code_challenge: codeChallenge,
|
||||
code_challenge_method: 'S256',
|
||||
state
|
||||
})
|
||||
|
||||
return `${OAUTH_CONFIG.AUTHORIZE_URL}?${params.toString()}`;
|
||||
return `${OAUTH_CONFIG.AUTHORIZE_URL}?${params.toString()}`
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -71,18 +69,18 @@ function generateAuthUrl(codeChallenge, state) {
|
||||
* @returns {{authUrl: string, codeVerifier: string, state: string, codeChallenge: string}}
|
||||
*/
|
||||
function generateOAuthParams() {
|
||||
const state = generateState();
|
||||
const codeVerifier = generateCodeVerifier();
|
||||
const codeChallenge = generateCodeChallenge(codeVerifier);
|
||||
|
||||
const authUrl = generateAuthUrl(codeChallenge, state);
|
||||
|
||||
return {
|
||||
authUrl,
|
||||
codeVerifier,
|
||||
state,
|
||||
codeChallenge
|
||||
};
|
||||
const state = generateState()
|
||||
const codeVerifier = generateCodeVerifier()
|
||||
const codeChallenge = generateCodeChallenge(codeVerifier)
|
||||
|
||||
const authUrl = generateAuthUrl(codeChallenge, state)
|
||||
|
||||
return {
|
||||
authUrl,
|
||||
codeVerifier,
|
||||
state,
|
||||
codeChallenge
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -91,25 +89,31 @@ function generateOAuthParams() {
|
||||
* @returns {object|null} 代理agent或null
|
||||
*/
|
||||
function createProxyAgent(proxyConfig) {
|
||||
if (!proxyConfig) {
|
||||
return null;
|
||||
}
|
||||
if (!proxyConfig) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
if (proxyConfig.type === 'socks5') {
|
||||
const auth = proxyConfig.username && proxyConfig.password ? `${proxyConfig.username}:${proxyConfig.password}@` : '';
|
||||
const socksUrl = `socks5://${auth}${proxyConfig.host}:${proxyConfig.port}`;
|
||||
return new SocksProxyAgent(socksUrl);
|
||||
} else if (proxyConfig.type === 'http' || proxyConfig.type === 'https') {
|
||||
const auth = proxyConfig.username && proxyConfig.password ? `${proxyConfig.username}:${proxyConfig.password}@` : '';
|
||||
const httpUrl = `${proxyConfig.type}://${auth}${proxyConfig.host}:${proxyConfig.port}`;
|
||||
return new HttpsProxyAgent(httpUrl);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('⚠️ Invalid proxy configuration:', error);
|
||||
try {
|
||||
if (proxyConfig.type === 'socks5') {
|
||||
const auth =
|
||||
proxyConfig.username && proxyConfig.password
|
||||
? `${proxyConfig.username}:${proxyConfig.password}@`
|
||||
: ''
|
||||
const socksUrl = `socks5://${auth}${proxyConfig.host}:${proxyConfig.port}`
|
||||
return new SocksProxyAgent(socksUrl)
|
||||
} else if (proxyConfig.type === 'http' || proxyConfig.type === 'https') {
|
||||
const auth =
|
||||
proxyConfig.username && proxyConfig.password
|
||||
? `${proxyConfig.username}:${proxyConfig.password}@`
|
||||
: ''
|
||||
const httpUrl = `${proxyConfig.type}://${auth}${proxyConfig.host}:${proxyConfig.port}`
|
||||
return new HttpsProxyAgent(httpUrl)
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('⚠️ Invalid proxy configuration:', error)
|
||||
}
|
||||
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -121,110 +125,110 @@ function createProxyAgent(proxyConfig) {
|
||||
* @returns {Promise<object>} Claude格式的token响应
|
||||
*/
|
||||
async function exchangeCodeForTokens(authorizationCode, codeVerifier, state, proxyConfig = null) {
|
||||
// 清理授权码,移除URL片段
|
||||
const cleanedCode = authorizationCode.split('#')[0]?.split('&')[0] ?? authorizationCode;
|
||||
|
||||
const params = {
|
||||
grant_type: 'authorization_code',
|
||||
client_id: OAUTH_CONFIG.CLIENT_ID,
|
||||
code: cleanedCode,
|
||||
redirect_uri: OAUTH_CONFIG.REDIRECT_URI,
|
||||
code_verifier: codeVerifier,
|
||||
state: state
|
||||
};
|
||||
// 清理授权码,移除URL片段
|
||||
const cleanedCode = authorizationCode.split('#')[0]?.split('&')[0] ?? authorizationCode
|
||||
|
||||
// 创建代理agent
|
||||
const agent = createProxyAgent(proxyConfig);
|
||||
const params = {
|
||||
grant_type: 'authorization_code',
|
||||
client_id: OAUTH_CONFIG.CLIENT_ID,
|
||||
code: cleanedCode,
|
||||
redirect_uri: OAUTH_CONFIG.REDIRECT_URI,
|
||||
code_verifier: codeVerifier,
|
||||
state
|
||||
}
|
||||
|
||||
try {
|
||||
logger.debug('🔄 Attempting OAuth token exchange', {
|
||||
url: OAUTH_CONFIG.TOKEN_URL,
|
||||
codeLength: cleanedCode.length,
|
||||
codePrefix: cleanedCode.substring(0, 10) + '...',
|
||||
hasProxy: !!proxyConfig,
|
||||
proxyType: proxyConfig?.type || 'none'
|
||||
});
|
||||
// 创建代理agent
|
||||
const agent = createProxyAgent(proxyConfig)
|
||||
|
||||
const response = await axios.post(OAUTH_CONFIG.TOKEN_URL, params, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'claude-cli/1.0.56 (external, cli)',
|
||||
'Accept': 'application/json, text/plain, */*',
|
||||
'Accept-Language': 'en-US,en;q=0.9',
|
||||
'Referer': 'https://claude.ai/',
|
||||
'Origin': 'https://claude.ai'
|
||||
},
|
||||
httpsAgent: agent,
|
||||
timeout: 30000
|
||||
});
|
||||
try {
|
||||
logger.debug('🔄 Attempting OAuth token exchange', {
|
||||
url: OAUTH_CONFIG.TOKEN_URL,
|
||||
codeLength: cleanedCode.length,
|
||||
codePrefix: `${cleanedCode.substring(0, 10)}...`,
|
||||
hasProxy: !!proxyConfig,
|
||||
proxyType: proxyConfig?.type || 'none'
|
||||
})
|
||||
|
||||
logger.success('✅ OAuth token exchange successful', {
|
||||
status: response.status,
|
||||
hasAccessToken: !!response.data?.access_token,
|
||||
hasRefreshToken: !!response.data?.refresh_token,
|
||||
scopes: response.data?.scope
|
||||
});
|
||||
const response = await axios.post(OAUTH_CONFIG.TOKEN_URL, params, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'claude-cli/1.0.56 (external, cli)',
|
||||
Accept: 'application/json, text/plain, */*',
|
||||
'Accept-Language': 'en-US,en;q=0.9',
|
||||
Referer: 'https://claude.ai/',
|
||||
Origin: 'https://claude.ai'
|
||||
},
|
||||
httpsAgent: agent,
|
||||
timeout: 30000
|
||||
})
|
||||
|
||||
const data = response.data;
|
||||
|
||||
// 返回Claude格式的token数据
|
||||
return {
|
||||
accessToken: data.access_token,
|
||||
refreshToken: data.refresh_token,
|
||||
expiresAt: (Math.floor(Date.now() / 1000) + data.expires_in) * 1000,
|
||||
scopes: data.scope ? data.scope.split(' ') : ['user:inference', 'user:profile'],
|
||||
isMax: true
|
||||
};
|
||||
} catch (error) {
|
||||
// 处理axios错误响应
|
||||
if (error.response) {
|
||||
// 服务器返回了错误状态码
|
||||
const status = error.response.status;
|
||||
const errorData = error.response.data;
|
||||
|
||||
logger.error('❌ OAuth token exchange failed with server error', {
|
||||
status: status,
|
||||
statusText: error.response.statusText,
|
||||
headers: error.response.headers,
|
||||
data: errorData,
|
||||
codeLength: cleanedCode.length,
|
||||
codePrefix: cleanedCode.substring(0, 10) + '...'
|
||||
});
|
||||
|
||||
// 尝试从错误响应中提取有用信息
|
||||
let errorMessage = `HTTP ${status}`;
|
||||
|
||||
if (errorData) {
|
||||
if (typeof errorData === 'string') {
|
||||
errorMessage += `: ${errorData}`;
|
||||
} else if (errorData.error) {
|
||||
errorMessage += `: ${errorData.error}`;
|
||||
if (errorData.error_description) {
|
||||
errorMessage += ` - ${errorData.error_description}`;
|
||||
}
|
||||
} else {
|
||||
errorMessage += `: ${JSON.stringify(errorData)}`;
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(`Token exchange failed: ${errorMessage}`);
|
||||
} else if (error.request) {
|
||||
// 请求被发送但没有收到响应
|
||||
logger.error('❌ OAuth token exchange failed with network error', {
|
||||
message: error.message,
|
||||
code: error.code,
|
||||
hasProxy: !!proxyConfig
|
||||
});
|
||||
throw new Error('Token exchange failed: No response from server (network error or timeout)');
|
||||
} else {
|
||||
// 其他错误
|
||||
logger.error('❌ OAuth token exchange failed with unknown error', {
|
||||
message: error.message,
|
||||
stack: error.stack
|
||||
});
|
||||
throw new Error(`Token exchange failed: ${error.message}`);
|
||||
}
|
||||
logger.success('✅ OAuth token exchange successful', {
|
||||
status: response.status,
|
||||
hasAccessToken: !!response.data?.access_token,
|
||||
hasRefreshToken: !!response.data?.refresh_token,
|
||||
scopes: response.data?.scope
|
||||
})
|
||||
|
||||
const { data } = response
|
||||
|
||||
// 返回Claude格式的token数据
|
||||
return {
|
||||
accessToken: data.access_token,
|
||||
refreshToken: data.refresh_token,
|
||||
expiresAt: (Math.floor(Date.now() / 1000) + data.expires_in) * 1000,
|
||||
scopes: data.scope ? data.scope.split(' ') : ['user:inference', 'user:profile'],
|
||||
isMax: true
|
||||
}
|
||||
} catch (error) {
|
||||
// 处理axios错误响应
|
||||
if (error.response) {
|
||||
// 服务器返回了错误状态码
|
||||
const { status } = error.response
|
||||
const errorData = error.response.data
|
||||
|
||||
logger.error('❌ OAuth token exchange failed with server error', {
|
||||
status,
|
||||
statusText: error.response.statusText,
|
||||
headers: error.response.headers,
|
||||
data: errorData,
|
||||
codeLength: cleanedCode.length,
|
||||
codePrefix: `${cleanedCode.substring(0, 10)}...`
|
||||
})
|
||||
|
||||
// 尝试从错误响应中提取有用信息
|
||||
let errorMessage = `HTTP ${status}`
|
||||
|
||||
if (errorData) {
|
||||
if (typeof errorData === 'string') {
|
||||
errorMessage += `: ${errorData}`
|
||||
} else if (errorData.error) {
|
||||
errorMessage += `: ${errorData.error}`
|
||||
if (errorData.error_description) {
|
||||
errorMessage += ` - ${errorData.error_description}`
|
||||
}
|
||||
} else {
|
||||
errorMessage += `: ${JSON.stringify(errorData)}`
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(`Token exchange failed: ${errorMessage}`)
|
||||
} else if (error.request) {
|
||||
// 请求被发送但没有收到响应
|
||||
logger.error('❌ OAuth token exchange failed with network error', {
|
||||
message: error.message,
|
||||
code: error.code,
|
||||
hasProxy: !!proxyConfig
|
||||
})
|
||||
throw new Error('Token exchange failed: No response from server (network error or timeout)')
|
||||
} else {
|
||||
// 其他错误
|
||||
logger.error('❌ OAuth token exchange failed with unknown error', {
|
||||
message: error.message,
|
||||
stack: error.stack
|
||||
})
|
||||
throw new Error(`Token exchange failed: ${error.message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -233,47 +237,47 @@ async function exchangeCodeForTokens(authorizationCode, codeVerifier, state, pro
|
||||
* @returns {string} 授权码
|
||||
*/
|
||||
function parseCallbackUrl(input) {
|
||||
if (!input || typeof input !== 'string') {
|
||||
throw new Error('请提供有效的授权码或回调 URL');
|
||||
}
|
||||
if (!input || typeof input !== 'string') {
|
||||
throw new Error('请提供有效的授权码或回调 URL')
|
||||
}
|
||||
|
||||
const trimmedInput = input.trim();
|
||||
|
||||
// 情况1: 尝试作为完整URL解析
|
||||
if (trimmedInput.startsWith('http://') || trimmedInput.startsWith('https://')) {
|
||||
try {
|
||||
const urlObj = new URL(trimmedInput);
|
||||
const authorizationCode = urlObj.searchParams.get('code');
|
||||
const trimmedInput = input.trim()
|
||||
|
||||
if (!authorizationCode) {
|
||||
throw new Error('回调 URL 中未找到授权码 (code 参数)');
|
||||
}
|
||||
// 情况1: 尝试作为完整URL解析
|
||||
if (trimmedInput.startsWith('http://') || trimmedInput.startsWith('https://')) {
|
||||
try {
|
||||
const urlObj = new URL(trimmedInput)
|
||||
const authorizationCode = urlObj.searchParams.get('code')
|
||||
|
||||
return authorizationCode;
|
||||
} catch (error) {
|
||||
if (error.message.includes('回调 URL 中未找到授权码')) {
|
||||
throw error;
|
||||
}
|
||||
throw new Error('无效的 URL 格式,请检查回调 URL 是否正确');
|
||||
}
|
||||
if (!authorizationCode) {
|
||||
throw new Error('回调 URL 中未找到授权码 (code 参数)')
|
||||
}
|
||||
|
||||
return authorizationCode
|
||||
} catch (error) {
|
||||
if (error.message.includes('回调 URL 中未找到授权码')) {
|
||||
throw error
|
||||
}
|
||||
throw new Error('无效的 URL 格式,请检查回调 URL 是否正确')
|
||||
}
|
||||
|
||||
// 情况2: 直接的授权码(可能包含URL fragments)
|
||||
// 参考claude-code-login.js的处理方式:移除URL fragments和参数
|
||||
const cleanedCode = trimmedInput.split('#')[0]?.split('&')[0] ?? trimmedInput;
|
||||
|
||||
// 验证授权码格式(Claude的授权码通常是base64url格式)
|
||||
if (!cleanedCode || cleanedCode.length < 10) {
|
||||
throw new Error('授权码格式无效,请确保复制了完整的 Authorization Code');
|
||||
}
|
||||
|
||||
// 基本格式验证:授权码应该只包含字母、数字、下划线、连字符
|
||||
const validCodePattern = /^[A-Za-z0-9_-]+$/;
|
||||
if (!validCodePattern.test(cleanedCode)) {
|
||||
throw new Error('授权码包含无效字符,请检查是否复制了正确的 Authorization Code');
|
||||
}
|
||||
|
||||
return cleanedCode;
|
||||
}
|
||||
|
||||
// 情况2: 直接的授权码(可能包含URL fragments)
|
||||
// 参考claude-code-login.js的处理方式:移除URL fragments和参数
|
||||
const cleanedCode = trimmedInput.split('#')[0]?.split('&')[0] ?? trimmedInput
|
||||
|
||||
// 验证授权码格式(Claude的授权码通常是base64url格式)
|
||||
if (!cleanedCode || cleanedCode.length < 10) {
|
||||
throw new Error('授权码格式无效,请确保复制了完整的 Authorization Code')
|
||||
}
|
||||
|
||||
// 基本格式验证:授权码应该只包含字母、数字、下划线、连字符
|
||||
const validCodePattern = /^[A-Za-z0-9_-]+$/
|
||||
if (!validCodePattern.test(cleanedCode)) {
|
||||
throw new Error('授权码包含无效字符,请检查是否复制了正确的 Authorization Code')
|
||||
}
|
||||
|
||||
return cleanedCode
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -282,26 +286,26 @@ function parseCallbackUrl(input) {
|
||||
* @returns {object} claudeAiOauth格式的数据
|
||||
*/
|
||||
function formatClaudeCredentials(tokenData) {
|
||||
return {
|
||||
claudeAiOauth: {
|
||||
accessToken: tokenData.accessToken,
|
||||
refreshToken: tokenData.refreshToken,
|
||||
expiresAt: tokenData.expiresAt,
|
||||
scopes: tokenData.scopes,
|
||||
isMax: tokenData.isMax
|
||||
}
|
||||
};
|
||||
return {
|
||||
claudeAiOauth: {
|
||||
accessToken: tokenData.accessToken,
|
||||
refreshToken: tokenData.refreshToken,
|
||||
expiresAt: tokenData.expiresAt,
|
||||
scopes: tokenData.scopes,
|
||||
isMax: tokenData.isMax
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
OAUTH_CONFIG,
|
||||
generateOAuthParams,
|
||||
exchangeCodeForTokens,
|
||||
parseCallbackUrl,
|
||||
formatClaudeCredentials,
|
||||
generateState,
|
||||
generateCodeVerifier,
|
||||
generateCodeChallenge,
|
||||
generateAuthUrl,
|
||||
createProxyAgent
|
||||
};
|
||||
OAUTH_CONFIG,
|
||||
generateOAuthParams,
|
||||
exchangeCodeForTokens,
|
||||
parseCallbackUrl,
|
||||
formatClaudeCredentials,
|
||||
generateState,
|
||||
generateCodeVerifier,
|
||||
generateCodeChallenge,
|
||||
generateAuthUrl,
|
||||
createProxyAgent
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
const crypto = require('crypto');
|
||||
const logger = require('./logger');
|
||||
const crypto = require('crypto')
|
||||
const logger = require('./logger')
|
||||
|
||||
class SessionHelper {
|
||||
/**
|
||||
@@ -10,92 +10,104 @@ class SessionHelper {
|
||||
*/
|
||||
generateSessionHash(requestBody) {
|
||||
if (!requestBody || typeof requestBody !== 'object') {
|
||||
return null;
|
||||
return null
|
||||
}
|
||||
|
||||
let cacheableContent = '';
|
||||
const system = requestBody.system || '';
|
||||
const messages = requestBody.messages || [];
|
||||
let cacheableContent = ''
|
||||
const system = requestBody.system || ''
|
||||
const messages = requestBody.messages || []
|
||||
|
||||
// 1. 优先提取带有cache_control: {"type": "ephemeral"}的内容
|
||||
// 检查system中的cacheable内容
|
||||
if (Array.isArray(system)) {
|
||||
for (const part of system) {
|
||||
if (part && part.cache_control && part.cache_control.type === 'ephemeral') {
|
||||
cacheableContent += part.text || '';
|
||||
cacheableContent += part.text || ''
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查messages中的cacheable内容
|
||||
for (const msg of messages) {
|
||||
const content = msg.content || '';
|
||||
const content = msg.content || ''
|
||||
if (Array.isArray(content)) {
|
||||
for (const part of content) {
|
||||
if (part && part.cache_control && part.cache_control.type === 'ephemeral') {
|
||||
if (part.type === 'text') {
|
||||
cacheableContent += part.text || '';
|
||||
cacheableContent += part.text || ''
|
||||
}
|
||||
// 其他类型(如image)不参与hash计算
|
||||
}
|
||||
}
|
||||
} else if (typeof content === 'string' && msg.cache_control && msg.cache_control.type === 'ephemeral') {
|
||||
} else if (
|
||||
typeof content === 'string' &&
|
||||
msg.cache_control &&
|
||||
msg.cache_control.type === 'ephemeral'
|
||||
) {
|
||||
// 罕见情况,但需要检查
|
||||
cacheableContent += content;
|
||||
cacheableContent += content
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 如果有cacheable内容,直接使用
|
||||
if (cacheableContent) {
|
||||
const hash = crypto.createHash('sha256').update(cacheableContent).digest('hex').substring(0, 32);
|
||||
logger.debug(`📋 Session hash generated from cacheable content: ${hash}`);
|
||||
return hash;
|
||||
const hash = crypto
|
||||
.createHash('sha256')
|
||||
.update(cacheableContent)
|
||||
.digest('hex')
|
||||
.substring(0, 32)
|
||||
logger.debug(`📋 Session hash generated from cacheable content: ${hash}`)
|
||||
return hash
|
||||
}
|
||||
|
||||
// 3. Fallback: 使用system内容
|
||||
if (system) {
|
||||
let systemText = '';
|
||||
let systemText = ''
|
||||
if (typeof system === 'string') {
|
||||
systemText = system;
|
||||
systemText = system
|
||||
} else if (Array.isArray(system)) {
|
||||
systemText = system.map(part => part.text || '').join('');
|
||||
systemText = system.map((part) => part.text || '').join('')
|
||||
}
|
||||
|
||||
|
||||
if (systemText) {
|
||||
const hash = crypto.createHash('sha256').update(systemText).digest('hex').substring(0, 32);
|
||||
logger.debug(`📋 Session hash generated from system content: ${hash}`);
|
||||
return hash;
|
||||
const hash = crypto.createHash('sha256').update(systemText).digest('hex').substring(0, 32)
|
||||
logger.debug(`📋 Session hash generated from system content: ${hash}`)
|
||||
return hash
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 最后fallback: 使用第一条消息内容
|
||||
if (messages.length > 0) {
|
||||
const firstMessage = messages[0];
|
||||
let firstMessageText = '';
|
||||
|
||||
const firstMessage = messages[0]
|
||||
let firstMessageText = ''
|
||||
|
||||
if (typeof firstMessage.content === 'string') {
|
||||
firstMessageText = firstMessage.content;
|
||||
firstMessageText = firstMessage.content
|
||||
} else if (Array.isArray(firstMessage.content)) {
|
||||
if (!firstMessage.content) {
|
||||
logger.error('📋 Session hash generated from first message failed: ', firstMessage);
|
||||
logger.error('📋 Session hash generated from first message failed: ', firstMessage)
|
||||
}
|
||||
|
||||
firstMessageText = firstMessage.content
|
||||
.filter(part => part.type === 'text')
|
||||
.map(part => part.text || '')
|
||||
.join('');
|
||||
.filter((part) => part.type === 'text')
|
||||
.map((part) => part.text || '')
|
||||
.join('')
|
||||
}
|
||||
|
||||
|
||||
if (firstMessageText) {
|
||||
const hash = crypto.createHash('sha256').update(firstMessageText).digest('hex').substring(0, 32);
|
||||
logger.debug(`📋 Session hash generated from first message: ${hash}`);
|
||||
return hash;
|
||||
const hash = crypto
|
||||
.createHash('sha256')
|
||||
.update(firstMessageText)
|
||||
.digest('hex')
|
||||
.substring(0, 32)
|
||||
logger.debug(`📋 Session hash generated from first message: ${hash}`)
|
||||
return hash
|
||||
}
|
||||
}
|
||||
|
||||
// 无法生成会话哈希
|
||||
logger.debug('📋 Unable to generate session hash - no suitable content found');
|
||||
return null;
|
||||
logger.debug('📋 Unable to generate session hash - no suitable content found')
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -104,7 +116,7 @@ class SessionHelper {
|
||||
* @returns {string} - Redis键名
|
||||
*/
|
||||
getSessionRedisKey(sessionHash) {
|
||||
return `sticky_session:${sessionHash}`;
|
||||
return `sticky_session:${sessionHash}`
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -113,10 +125,12 @@ class SessionHelper {
|
||||
* @returns {boolean} - 是否有效
|
||||
*/
|
||||
isValidSessionHash(sessionHash) {
|
||||
return typeof sessionHash === 'string' &&
|
||||
sessionHash.length === 32 &&
|
||||
/^[a-f0-9]{32}$/.test(sessionHash);
|
||||
return (
|
||||
typeof sessionHash === 'string' &&
|
||||
sessionHash.length === 32 &&
|
||||
/^[a-f0-9]{32}$/.test(sessionHash)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = new SessionHelper();
|
||||
module.exports = new SessionHelper()
|
||||
|
||||
@@ -11,29 +11,29 @@
|
||||
*/
|
||||
function maskToken(token, visiblePercent = 70) {
|
||||
if (!token || typeof token !== 'string') {
|
||||
return '[EMPTY]';
|
||||
return '[EMPTY]'
|
||||
}
|
||||
|
||||
const length = token.length;
|
||||
|
||||
const { length } = token
|
||||
|
||||
// 对于非常短的 token,至少隐藏一部分
|
||||
if (length <= 10) {
|
||||
return token.slice(0, 5) + '*'.repeat(length - 5);
|
||||
return token.slice(0, 5) + '*'.repeat(length - 5)
|
||||
}
|
||||
|
||||
// 计算可见字符数量
|
||||
const visibleLength = Math.floor(length * (visiblePercent / 100));
|
||||
|
||||
const visibleLength = Math.floor(length * (visiblePercent / 100))
|
||||
|
||||
// 在前部和尾部分配可见字符
|
||||
const frontLength = Math.ceil(visibleLength * 0.6);
|
||||
const backLength = visibleLength - frontLength;
|
||||
|
||||
const frontLength = Math.ceil(visibleLength * 0.6)
|
||||
const backLength = visibleLength - frontLength
|
||||
|
||||
// 构建脱敏后的 token
|
||||
const front = token.slice(0, frontLength);
|
||||
const back = token.slice(-backLength);
|
||||
const middle = '*'.repeat(length - visibleLength);
|
||||
|
||||
return `${front}${middle}${back}`;
|
||||
const front = token.slice(0, frontLength)
|
||||
const back = token.slice(-backLength)
|
||||
const middle = '*'.repeat(length - visibleLength)
|
||||
|
||||
return `${front}${middle}${back}`
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -42,20 +42,23 @@ function maskToken(token, visiblePercent = 70) {
|
||||
* @param {Array<string>} tokenFields - 需要脱敏的字段名列表
|
||||
* @returns {Object} 脱敏后的对象副本
|
||||
*/
|
||||
function maskTokensInObject(obj, tokenFields = ['accessToken', 'refreshToken', 'access_token', 'refresh_token']) {
|
||||
function maskTokensInObject(
|
||||
obj,
|
||||
tokenFields = ['accessToken', 'refreshToken', 'access_token', 'refresh_token']
|
||||
) {
|
||||
if (!obj || typeof obj !== 'object') {
|
||||
return obj;
|
||||
return obj
|
||||
}
|
||||
|
||||
const masked = { ...obj };
|
||||
|
||||
tokenFields.forEach(field => {
|
||||
const masked = { ...obj }
|
||||
|
||||
tokenFields.forEach((field) => {
|
||||
if (masked[field]) {
|
||||
masked[field] = maskToken(masked[field]);
|
||||
masked[field] = maskToken(masked[field])
|
||||
}
|
||||
});
|
||||
|
||||
return masked;
|
||||
})
|
||||
|
||||
return masked
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -75,21 +78,21 @@ function formatTokenRefreshLog(accountId, accountName, tokens, status, message =
|
||||
accountName,
|
||||
status,
|
||||
message
|
||||
};
|
||||
}
|
||||
|
||||
if (tokens) {
|
||||
log.tokens = {
|
||||
accessToken: tokens.accessToken ? maskToken(tokens.accessToken) : '[NOT_PROVIDED]',
|
||||
refreshToken: tokens.refreshToken ? maskToken(tokens.refreshToken) : '[NOT_PROVIDED]',
|
||||
expiresAt: tokens.expiresAt || '[NOT_PROVIDED]'
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return log;
|
||||
return log
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
maskToken,
|
||||
maskTokensInObject,
|
||||
formatTokenRefreshLog
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const winston = require('winston');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
const { maskToken } = require('./tokenMask');
|
||||
const winston = require('winston')
|
||||
const path = require('path')
|
||||
const fs = require('fs')
|
||||
const { maskToken } = require('./tokenMask')
|
||||
|
||||
// 确保日志目录存在
|
||||
const logDir = path.join(process.cwd(), 'logs');
|
||||
const logDir = path.join(process.cwd(), 'logs')
|
||||
if (!fs.existsSync(logDir)) {
|
||||
fs.mkdirSync(logDir, { recursive: true });
|
||||
fs.mkdirSync(logDir, { recursive: true })
|
||||
}
|
||||
|
||||
// 创建专用的 token 刷新日志记录器
|
||||
@@ -17,9 +17,7 @@ const tokenRefreshLogger = winston.createLogger({
|
||||
format: 'YYYY-MM-DD HH:mm:ss.SSS'
|
||||
}),
|
||||
winston.format.json(),
|
||||
winston.format.printf(info => {
|
||||
return JSON.stringify(info, null, 2);
|
||||
})
|
||||
winston.format.printf((info) => JSON.stringify(info, null, 2))
|
||||
),
|
||||
transports: [
|
||||
// 文件传输 - 每日轮转
|
||||
@@ -39,16 +37,15 @@ const tokenRefreshLogger = winston.createLogger({
|
||||
],
|
||||
// 错误处理
|
||||
exitOnError: false
|
||||
});
|
||||
})
|
||||
|
||||
// 在开发环境添加控制台输出
|
||||
if (process.env.NODE_ENV !== 'production') {
|
||||
tokenRefreshLogger.add(new winston.transports.Console({
|
||||
format: winston.format.combine(
|
||||
winston.format.colorize(),
|
||||
winston.format.simple()
|
||||
)
|
||||
}));
|
||||
tokenRefreshLogger.add(
|
||||
new winston.transports.Console({
|
||||
format: winston.format.combine(winston.format.colorize(), winston.format.simple())
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -62,7 +59,7 @@ function logRefreshStart(accountId, accountName, platform = 'claude', reason = '
|
||||
platform,
|
||||
reason,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -74,7 +71,7 @@ function logRefreshSuccess(accountId, accountName, platform = 'claude', tokenDat
|
||||
refreshToken: tokenData.refreshToken ? maskToken(tokenData.refreshToken) : '[NOT_PROVIDED]',
|
||||
expiresAt: tokenData.expiresAt || tokenData.expiry_date || '[NOT_PROVIDED]',
|
||||
scopes: tokenData.scopes || tokenData.scope || '[NOT_PROVIDED]'
|
||||
};
|
||||
}
|
||||
|
||||
tokenRefreshLogger.info({
|
||||
event: 'token_refresh_success',
|
||||
@@ -83,7 +80,7 @@ function logRefreshSuccess(accountId, accountName, platform = 'claude', tokenDat
|
||||
platform,
|
||||
tokenData: maskedTokenData,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -95,7 +92,7 @@ function logRefreshError(accountId, accountName, platform = 'claude', error, att
|
||||
code: error.code || 'UNKNOWN',
|
||||
statusCode: error.response?.status || 'N/A',
|
||||
responseData: error.response?.data || 'N/A'
|
||||
};
|
||||
}
|
||||
|
||||
tokenRefreshLogger.error({
|
||||
event: 'token_refresh_error',
|
||||
@@ -105,7 +102,7 @@ function logRefreshError(accountId, accountName, platform = 'claude', error, att
|
||||
error: errorInfo,
|
||||
attemptNumber,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -119,7 +116,7 @@ function logRefreshSkipped(accountId, accountName, platform = 'claude', reason =
|
||||
platform,
|
||||
reason,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -135,7 +132,7 @@ function logTokenUsage(accountId, accountName, platform = 'claude', expiresAt, i
|
||||
isExpired,
|
||||
remainingMinutes: expiresAt ? Math.floor((new Date(expiresAt) - Date.now()) / 60000) : 'N/A',
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -147,7 +144,7 @@ function logBatchRefreshStart(totalAccounts, platform = 'all') {
|
||||
totalAccounts,
|
||||
platform,
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -163,7 +160,7 @@ function logBatchRefreshComplete(results) {
|
||||
skipped: results.skipped || 0
|
||||
},
|
||||
timestamp: new Date().toISOString()
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
@@ -175,4 +172,4 @@ module.exports = {
|
||||
logTokenUsage,
|
||||
logBatchRefreshStart,
|
||||
logBatchRefreshComplete
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user