feat: 为API Key添加模型限制功能

- 前端:在API Key创建和编辑表单中添加模型限制开关和标签输入
- 前端:支持动态添加/删除限制的模型列表
- 后端:更新API Key数据结构,新增enableModelRestriction和restrictedModels字段
- 后端:在中转请求时检查模型访问权限
- 修复:Enter键提交表单问题,使用@keydown.enter.prevent
- 优化:限制模型数据持久化,关闭开关时不清空数据

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
shaw
2025-07-19 20:54:26 +08:00
parent f9933f7061
commit f962083752
9 changed files with 278 additions and 22 deletions

View File

@@ -109,7 +109,9 @@ const authenticateApiKey = async (req, res, next) => {
name: validation.keyData.name,
tokenLimit: validation.keyData.tokenLimit,
claudeAccountId: validation.keyData.claudeAccountId,
concurrencyLimit: validation.keyData.concurrencyLimit
concurrencyLimit: validation.keyData.concurrencyLimit,
enableModelRestriction: validation.keyData.enableModelRestriction,
restrictedModels: validation.keyData.restrictedModels
};
req.usage = validation.keyData.usage;

View File

@@ -2,6 +2,26 @@ const Redis = require('ioredis');
const config = require('../../config/config');
const logger = require('../utils/logger');
// 时区辅助函数
function getDateInTimezone(date = new Date()) {
const offset = config.system.timezoneOffset || 8; // 默认UTC+8
const utcTime = date.getTime() + (date.getTimezoneOffset() * 60000);
const targetTime = new Date(utcTime + (offset * 3600000));
return targetTime;
}
// 获取配置时区的日期字符串 (YYYY-MM-DD)
function getDateStringInTimezone(date = new Date()) {
const tzDate = getDateInTimezone(date);
return `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}-${String(tzDate.getDate()).padStart(2, '0')}`;
}
// 获取配置时区的小时 (0-23)
function getHourInTimezone(date = new Date()) {
const tzDate = getDateInTimezone(date);
return tzDate.getHours();
}
class RedisClient {
constructor() {
this.client = null;
@@ -140,9 +160,10 @@ class RedisClient {
async incrementTokenUsage(keyId, tokens, inputTokens = 0, outputTokens = 0, cacheCreateTokens = 0, cacheReadTokens = 0, model = 'unknown') {
const key = `usage:${keyId}`;
const now = new Date();
const today = now.toISOString().split('T')[0];
const currentMonth = `${now.getFullYear()}-${String(now.getMonth() + 1).padStart(2, '0')}`;
const currentHour = `${today}:${String(now.getHours()).padStart(2, '0')}`; // 新增小时级别
const today = getDateStringInTimezone(now);
const tzDate = getDateInTimezone(now);
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`;
const currentHour = `${today}:${String(getHourInTimezone(now)).padStart(2, '0')}`; // 新增小时级别
const daily = `usage:daily:${keyId}:${today}`;
const monthly = `usage:monthly:${keyId}:${currentMonth}`;
@@ -263,9 +284,10 @@ class RedisClient {
async getUsageStats(keyId) {
const totalKey = `usage:${keyId}`;
const today = new Date().toISOString().split('T')[0];
const today = getDateStringInTimezone();
const dailyKey = `usage:daily:${keyId}:${today}`;
const currentMonth = `${new Date().getFullYear()}-${String(new Date().getMonth() + 1).padStart(2, '0')}`;
const tzDate = getDateInTimezone();
const currentMonth = `${tzDate.getFullYear()}-${String(tzDate.getMonth() + 1).padStart(2, '0')}`;
const monthlyKey = `usage:monthly:${keyId}:${currentMonth}`;
const [total, daily, monthly] = await Promise.all([
@@ -534,7 +556,7 @@ class RedisClient {
// 📊 获取今日系统统计
async getTodayStats() {
try {
const today = new Date().toISOString().split('T')[0];
const today = getDateStringInTimezone();
const dailyKeys = await this.client.keys(`usage:daily:*:${today}`);
let totalRequestsToday = 0;

View File

@@ -32,7 +32,9 @@ router.post('/api-keys', authenticateAdmin, async (req, res) => {
tokenLimit,
expiresAt,
claudeAccountId,
concurrencyLimit
concurrencyLimit,
enableModelRestriction,
restrictedModels
} = req.body;
// 输入验证
@@ -57,13 +59,24 @@ router.post('/api-keys', authenticateAdmin, async (req, res) => {
return res.status(400).json({ error: 'Concurrency limit must be a non-negative integer' });
}
// 验证模型限制字段
if (enableModelRestriction !== undefined && typeof enableModelRestriction !== 'boolean') {
return res.status(400).json({ error: 'Enable model restriction must be a boolean' });
}
if (restrictedModels !== undefined && !Array.isArray(restrictedModels)) {
return res.status(400).json({ error: 'Restricted models must be an array' });
}
const newKey = await apiKeyService.generateApiKey({
name,
description,
tokenLimit,
expiresAt,
claudeAccountId,
concurrencyLimit
concurrencyLimit,
enableModelRestriction,
restrictedModels
});
logger.success(`🔑 Admin created new API key: ${name}`);
@@ -78,9 +91,9 @@ router.post('/api-keys', authenticateAdmin, async (req, res) => {
router.put('/api-keys/:keyId', authenticateAdmin, async (req, res) => {
try {
const { keyId } = req.params;
const { tokenLimit, concurrencyLimit, claudeAccountId } = req.body;
const { tokenLimit, concurrencyLimit, claudeAccountId, enableModelRestriction, restrictedModels } = req.body;
// 只允许更新tokenLimit、concurrencyLimit和claudeAccountId
// 只允许更新指定字段
const updates = {};
if (tokenLimit !== undefined && tokenLimit !== null && tokenLimit !== '') {
@@ -102,6 +115,21 @@ router.put('/api-keys/:keyId', authenticateAdmin, async (req, res) => {
updates.claudeAccountId = claudeAccountId || '';
}
// 处理模型限制字段
if (enableModelRestriction !== undefined) {
if (typeof enableModelRestriction !== 'boolean') {
return res.status(400).json({ error: 'Enable model restriction must be a boolean' });
}
updates.enableModelRestriction = enableModelRestriction;
}
if (restrictedModels !== undefined) {
if (!Array.isArray(restrictedModels)) {
return res.status(400).json({ error: 'Restricted models must be an array' });
}
updates.restrictedModels = restrictedModels;
}
await apiKeyService.updateApiKey(keyId, updates);
logger.success(`📝 Admin updated API key: ${keyId}`);

View File

@@ -18,7 +18,9 @@ class ApiKeyService {
expiresAt = null,
claudeAccountId = null,
isActive = true,
concurrencyLimit = 0
concurrencyLimit = 0,
enableModelRestriction = false,
restrictedModels = []
} = options;
// 生成简单的API Key (64字符十六进制)
@@ -35,6 +37,8 @@ class ApiKeyService {
concurrencyLimit: String(concurrencyLimit ?? 0),
isActive: String(isActive),
claudeAccountId: claudeAccountId || '',
enableModelRestriction: String(enableModelRestriction),
restrictedModels: JSON.stringify(restrictedModels || []),
createdAt: new Date().toISOString(),
lastUsedAt: '',
expiresAt: expiresAt || '',
@@ -55,6 +59,8 @@ class ApiKeyService {
concurrencyLimit: parseInt(keyData.concurrencyLimit),
isActive: keyData.isActive === 'true',
claudeAccountId: keyData.claudeAccountId,
enableModelRestriction: keyData.enableModelRestriction === 'true',
restrictedModels: JSON.parse(keyData.restrictedModels),
createdAt: keyData.createdAt,
expiresAt: keyData.expiresAt,
createdBy: keyData.createdBy
@@ -131,6 +137,12 @@ class ApiKeyService {
key.concurrencyLimit = parseInt(key.concurrencyLimit || 0);
key.currentConcurrency = await redis.getConcurrency(key.id);
key.isActive = key.isActive === 'true';
key.enableModelRestriction = key.enableModelRestriction === 'true';
try {
key.restrictedModels = key.restrictedModels ? JSON.parse(key.restrictedModels) : [];
} catch (e) {
key.restrictedModels = [];
}
delete key.apiKey; // 不返回哈希后的key
}
@@ -150,12 +162,20 @@ class ApiKeyService {
}
// 允许更新的字段
const allowedUpdates = ['name', 'description', 'tokenLimit', 'concurrencyLimit', 'isActive', 'claudeAccountId', 'expiresAt'];
const allowedUpdates = ['name', 'description', 'tokenLimit', 'concurrencyLimit', 'isActive', 'claudeAccountId', 'expiresAt', 'enableModelRestriction', 'restrictedModels'];
const updatedData = { ...keyData };
for (const [field, value] of Object.entries(updates)) {
if (allowedUpdates.includes(field)) {
updatedData[field] = (value != null ? value : '').toString();
if (field === 'restrictedModels') {
// 特殊处理 restrictedModels 数组
updatedData[field] = JSON.stringify(value || []);
} else if (field === 'enableModelRestriction') {
// 布尔值转字符串
updatedData[field] = String(value);
} else {
updatedData[field] = (value != null ? value : '').toString();
}
}
}

View File

@@ -22,6 +22,24 @@ class ClaudeRelayService {
let upstreamRequest = null;
try {
// 检查模型限制
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels && apiKeyData.restrictedModels.length > 0) {
const requestedModel = requestBody.model;
if (requestedModel && apiKeyData.restrictedModels.includes(requestedModel)) {
logger.warn(`🚫 Model restriction violation for key ${apiKeyData.name}: Attempted to use restricted model ${requestedModel}`);
return {
statusCode: 403,
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
error: {
type: 'forbidden',
message: '暂无该模型访问权限'
}
})
};
}
}
// 生成会话哈希用于sticky会话
const sessionHash = sessionHelper.generateSessionHash(requestBody);
@@ -419,6 +437,26 @@ class ClaudeRelayService {
// 🌊 处理流式响应带usage数据捕获
async relayStreamRequestWithUsageCapture(requestBody, apiKeyData, responseStream, clientHeaders, usageCallback) {
try {
// 检查模型限制
if (apiKeyData.enableModelRestriction && apiKeyData.restrictedModels && apiKeyData.restrictedModels.length > 0) {
const requestedModel = requestBody.model;
if (requestedModel && apiKeyData.restrictedModels.includes(requestedModel)) {
logger.warn(`🚫 Model restriction violation for key ${apiKeyData.name}: Attempted to use restricted model ${requestedModel}`);
// 对于流式响应,需要写入错误并结束流
const errorResponse = JSON.stringify({
error: {
type: 'forbidden',
message: '暂无该模型访问权限'
}
});
responseStream.writeHead(403, { 'Content-Type': 'application/json' });
responseStream.end(errorResponse);
return;
}
}
// 生成会话哈希用于sticky会话
const sessionHash = sessionHelper.generateSessionHash(requestBody);