mirror of
https://github.com/Wei-Shaw/claude-relay-service.git
synced 2026-01-22 16:43:35 +00:00
feat: 为API Key添加模型限制功能
- 前端:在API Key创建和编辑表单中添加模型限制开关和标签输入 - 前端:支持动态添加/删除限制的模型列表 - 后端:更新API Key数据结构,新增enableModelRestriction和restrictedModels字段 - 后端:在中转请求时检查模型访问权限 - 修复:Enter键提交表单问题,使用@keydown.enter.prevent - 优化:限制模型数据持久化,关闭开关时不清空数据 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -109,7 +109,9 @@ const authenticateApiKey = async (req, res, next) => {
|
||||
name: validation.keyData.name,
|
||||
tokenLimit: validation.keyData.tokenLimit,
|
||||
claudeAccountId: validation.keyData.claudeAccountId,
|
||||
concurrencyLimit: validation.keyData.concurrencyLimit
|
||||
concurrencyLimit: validation.keyData.concurrencyLimit,
|
||||
enableModelRestriction: validation.keyData.enableModelRestriction,
|
||||
restrictedModels: validation.keyData.restrictedModels
|
||||
};
|
||||
req.usage = validation.keyData.usage;
|
||||
|
||||
|
||||
@@ -2,6 +2,26 @@ const Redis = require('ioredis');
|
||||
const config = require('../../config/config');
|
||||
const logger = require('../utils/logger');
|
||||
|
||||
// 时区辅助函数
|
||||
function getDateInTimezone(date = new Date()) {
|
||||
const offset = config.system.timezoneOffset || 8; // 默认UTC+8
|
||||
const utcTime = date.getTime() + (date.getTimezoneOffset() * 60000);
|
||||
const targetTime = new Date(utcTime + (offset * 3600000));
|
||||
return targetTime;
|
||||
}
|
||||
|
||||
// 获取配置时区的日期字符串 (YYYY-MM-DD)
|
||||
function getDateStringInTimezone(date = new Date()) {
|
||||
const tzDate = getDateInTimezone(date);
|
||||
return `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}-${String(tzDate.getDate()).padStart(2, '0')}`;
|
||||
}
|
||||
|
||||
// 获取配置时区的小时 (0-23)
|
||||
function getHourInTimezone(date = new Date()) {
|
||||
const tzDate = getDateInTimezone(date);
|
||||
return tzDate.getHours();
|
||||
}
|
||||
|
||||
class RedisClient {
|
||||
constructor() {
|
||||
this.client = null;
|
||||
@@ -140,9 +160,10 @@ class RedisClient {
|
||||
async incrementTokenUsage(keyId, tokens, inputTokens = 0, outputTokens = 0, cacheCreateTokens = 0, cacheReadTokens = 0, model = 'unknown') {
|
||||
const key = `usage:${keyId}`;
|
||||
const now = new Date();
|
||||
const today = now.toISOString().split('T')[0];
|
||||
const currentMonth = `${now.getFullYear()}-${String(now.getMonth() + 1).padStart(2, '0')}`;
|
||||
const currentHour = `${today}:${String(now.getHours()).padStart(2, '0')}`; // 新增小时级别
|
||||
const today = getDateStringInTimezone(now);
|
||||
const tzDate = getDateInTimezone(now);
|
||||
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`;
|
||||
const currentHour = `${today}:${String(getHourInTimezone(now)).padStart(2, '0')}`; // 新增小时级别
|
||||
|
||||
const daily = `usage:daily:${keyId}:${today}`;
|
||||
const monthly = `usage:monthly:${keyId}:${currentMonth}`;
|
||||
@@ -263,9 +284,10 @@ class RedisClient {
|
||||
|
||||
async getUsageStats(keyId) {
|
||||
const totalKey = `usage:${keyId}`;
|
||||
const today = new Date().toISOString().split('T')[0];
|
||||
const today = getDateStringInTimezone();
|
||||
const dailyKey = `usage:daily:${keyId}:${today}`;
|
||||
const currentMonth = `${new Date().getFullYear()}-${String(new Date().getMonth() + 1).padStart(2, '0')}`;
|
||||
const tzDate = getDateInTimezone();
|
||||
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`;
|
||||
const monthlyKey = `usage:monthly:${keyId}:${currentMonth}`;
|
||||
|
||||
const [total, daily, monthly] = await Promise.all([
|
||||
@@ -534,7 +556,7 @@ class RedisClient {
|
||||
// 📊 获取今日系统统计
|
||||
async getTodayStats() {
|
||||
try {
|
||||
const today = new Date().toISOString().split('T')[0];
|
||||
const today = getDateStringInTimezone();
|
||||
const dailyKeys = await this.client.keys(`usage:daily:*:${today}`);
|
||||
|
||||
let totalRequestsToday = 0;
|
||||
|
||||
@@ -32,7 +32,9 @@ router.post('/api-keys', authenticateAdmin, async (req, res) => {
|
||||
tokenLimit,
|
||||
expiresAt,
|
||||
claudeAccountId,
|
||||
concurrencyLimit
|
||||
concurrencyLimit,
|
||||
enableModelRestriction,
|
||||
restrictedModels
|
||||
} = req.body;
|
||||
|
||||
// 输入验证
|
||||
@@ -57,13 +59,24 @@ router.post('/api-keys', authenticateAdmin, async (req, res) => {
|
||||
return res.status(400).json({ error: 'Concurrency limit must be a non-negative integer' });
|
||||
}
|
||||
|
||||
// 验证模型限制字段
|
||||
if (enableModelRestriction !== undefined && typeof enableModelRestriction !== 'boolean') {
|
||||
return res.status(400).json({ error: 'Enable model restriction must be a boolean' });
|
||||
}
|
||||
|
||||
if (restrictedModels !== undefined && !Array.isArray(restrictedModels)) {
|
||||
return res.status(400).json({ error: 'Restricted models must be an array' });
|
||||
}
|
||||
|
||||
const newKey = await apiKeyService.generateApiKey({
|
||||
name,
|
||||
description,
|
||||
tokenLimit,
|
||||
expiresAt,
|
||||
claudeAccountId,
|
||||
concurrencyLimit
|
||||
concurrencyLimit,
|
||||
enableModelRestriction,
|
||||
restrictedModels
|
||||
});
|
||||
|
||||
logger.success(`🔑 Admin created new API key: ${name}`);
|
||||
@@ -78,9 +91,9 @@ router.post('/api-keys', authenticateAdmin, async (req, res) => {
|
||||
router.put('/api-keys/:keyId', authenticateAdmin, async (req, res) => {
|
||||
try {
|
||||
const { keyId } = req.params;
|
||||
const { tokenLimit, concurrencyLimit, claudeAccountId } = req.body;
|
||||
const { tokenLimit, concurrencyLimit, claudeAccountId, enableModelRestriction, restrictedModels } = req.body;
|
||||
|
||||
// 只允许更新tokenLimit、concurrencyLimit和claudeAccountId
|
||||
// 只允许更新指定字段
|
||||
const updates = {};
|
||||
|
||||
if (tokenLimit !== undefined && tokenLimit !== null && tokenLimit !== '') {
|
||||
@@ -102,6 +115,21 @@ router.put('/api-keys/:keyId', authenticateAdmin, async (req, res) => {
|
||||
updates.claudeAccountId = claudeAccountId || '';
|
||||
}
|
||||
|
||||
// 处理模型限制字段
|
||||
if (enableModelRestriction !== undefined) {
|
||||
if (typeof enableModelRestriction !== 'boolean') {
|
||||
return res.status(400).json({ error: 'Enable model restriction must be a boolean' });
|
||||
}
|
||||
updates.enableModelRestriction = enableModelRestriction;
|
||||
}
|
||||
|
||||
if (restrictedModels !== undefined) {
|
||||
if (!Array.isArray(restrictedModels)) {
|
||||
return res.status(400).json({ error: 'Restricted models must be an array' });
|
||||
}
|
||||
updates.restrictedModels = restrictedModels;
|
||||
}
|
||||
|
||||
await apiKeyService.updateApiKey(keyId, updates);
|
||||
|
||||
logger.success(`📝 Admin updated API key: ${keyId}`);
|
||||
|
||||
@@ -18,7 +18,9 @@ class ApiKeyService {
|
||||
expiresAt = null,
|
||||
claudeAccountId = null,
|
||||
isActive = true,
|
||||
concurrencyLimit = 0
|
||||
concurrencyLimit = 0,
|
||||
enableModelRestriction = false,
|
||||
restrictedModels = []
|
||||
} = options;
|
||||
|
||||
// 生成简单的API Key (64字符十六进制)
|
||||
@@ -35,6 +37,8 @@ class ApiKeyService {
|
||||
concurrencyLimit: String(concurrencyLimit ?? 0),
|
||||
isActive: String(isActive),
|
||||
claudeAccountId: claudeAccountId || '',
|
||||
enableModelRestriction: String(enableModelRestriction),
|
||||
restrictedModels: JSON.stringify(restrictedModels || []),
|
||||
createdAt: new Date().toISOString(),
|
||||
lastUsedAt: '',
|
||||
expiresAt: expiresAt || '',
|
||||
@@ -55,6 +59,8 @@ class ApiKeyService {
|
||||
concurrencyLimit: parseInt(keyData.concurrencyLimit),
|
||||
isActive: keyData.isActive === 'true',
|
||||
claudeAccountId: keyData.claudeAccountId,
|
||||
enableModelRestriction: keyData.enableModelRestriction === 'true',
|
||||
restrictedModels: JSON.parse(keyData.restrictedModels),
|
||||
createdAt: keyData.createdAt,
|
||||
expiresAt: keyData.expiresAt,
|
||||
createdBy: keyData.createdBy
|
||||
@@ -131,6 +137,12 @@ class ApiKeyService {
|
||||
key.concurrencyLimit = parseInt(key.concurrencyLimit || 0);
|
||||
key.currentConcurrency = await redis.getConcurrency(key.id);
|
||||
key.isActive = key.isActive === 'true';
|
||||
key.enableModelRestriction = key.enableModelRestriction === 'true';
|
||||
try {
|
||||
key.restrictedModels = key.restrictedModels ? JSON.parse(key.restrictedModels) : [];
|
||||
} catch (e) {
|
||||
key.restrictedModels = [];
|
||||
}
|
||||
delete key.apiKey; // 不返回哈希后的key
|
||||
}
|
||||
|
||||
@@ -150,12 +162,20 @@ class ApiKeyService {
|
||||
}
|
||||
|
||||
// 允许更新的字段
|
||||
const allowedUpdates = ['name', 'description', 'tokenLimit', 'concurrencyLimit', 'isActive', 'claudeAccountId', 'expiresAt'];
|
||||
const allowedUpdates = ['name', 'description', 'tokenLimit', 'concurrencyLimit', 'isActive', 'claudeAccountId', 'expiresAt', 'enableModelRestriction', 'restrictedModels'];
|
||||
const updatedData = { ...keyData };
|
||||
|
||||
for (const [field, value] of Object.entries(updates)) {
|
||||
if (allowedUpdates.includes(field)) {
|
||||
updatedData[field] = (value != null ? value : '').toString();
|
||||
if (field === 'restrictedModels') {
|
||||
// 特殊处理 restrictedModels 数组
|
||||
updatedData[field] = JSON.stringify(value || []);
|
||||
} else if (field === 'enableModelRestriction') {
|
||||
// 布尔值转字符串
|
||||
updatedData[field] = String(value);
|
||||
} else {
|
||||
updatedData[field] = (value != null ? value : '').toString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,24 @@ class ClaudeRelayService {
|
||||
let upstreamRequest = null;
|
||||
|
||||
try {
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels && apiKeyData.restrictedModels.length > 0) {
|
||||
const requestedModel = requestBody.model;
|
||||
if (requestedModel && apiKeyData.restrictedModels.includes(requestedModel)) {
|
||||
logger.warn(`🚫 Model restriction violation for key ${apiKeyData.name}: Attempted to use restricted model ${requestedModel}`);
|
||||
return {
|
||||
statusCode: 403,
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
error: {
|
||||
type: 'forbidden',
|
||||
message: '暂无该模型访问权限'
|
||||
}
|
||||
})
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(requestBody);
|
||||
|
||||
@@ -419,6 +437,26 @@ class ClaudeRelayService {
|
||||
// 🌊 处理流式响应(带usage数据捕获)
|
||||
async relayStreamRequestWithUsageCapture(requestBody, apiKeyData, responseStream, clientHeaders, usageCallback) {
|
||||
try {
|
||||
// 检查模型限制
|
||||
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels && apiKeyData.restrictedModels.length > 0) {
|
||||
const requestedModel = requestBody.model;
|
||||
if (requestedModel && apiKeyData.restrictedModels.includes(requestedModel)) {
|
||||
logger.warn(`🚫 Model restriction violation for key ${apiKeyData.name}: Attempted to use restricted model ${requestedModel}`);
|
||||
|
||||
// 对于流式响应,需要写入错误并结束流
|
||||
const errorResponse = JSON.stringify({
|
||||
error: {
|
||||
type: 'forbidden',
|
||||
message: '暂无该模型访问权限'
|
||||
}
|
||||
});
|
||||
|
||||
responseStream.writeHead(403, { 'Content-Type': 'application/json' });
|
||||
responseStream.end(errorResponse);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 生成会话哈希用于sticky会话
|
||||
const sessionHash = sessionHelper.generateSessionHash(requestBody);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user