Merge commit from fork

fix: harden token search with pagination, rate limiting and input validation
This commit is contained in:
Calcium-Ion
2026-02-06 17:54:40 +08:00
committed by GitHub
11 changed files with 282 additions and 20 deletions

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
return tokens, err
}
func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
// 规则:
// 1. 转义 _ 和 \(不允许 _ 作通配符)
// 2. 连续的 % 合并为单个 %
// 3. 最多允许 2 个 %
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
// 5. 不含 % 时按精确匹配
func sanitizeLikePattern(input string) (string, error) {
// 1. 转义 \ 和 _
input = strings.ReplaceAll(input, `\`, `\\`)
input = strings.ReplaceAll(input, `_`, `\_`)
// 2. 连续的 % 直接拒绝
if strings.Contains(input, "%%") {
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
}
// 3. 统计 % 数量,不得超过 2
count := strings.Count(input, "%")
if count > 2 {
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
}
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
if count > 0 {
stripped := strings.ReplaceAll(input, "%", "")
if len(stripped) < 2 {
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
}
return input, nil
}
// 5. 无 % 时,精确全匹配
return input, nil
}
const searchHardLimit = 100
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
// model 层强制截断
if limit <= 0 || limit > searchHardLimit {
limit = searchHardLimit
}
if offset < 0 {
offset = 0
}
if token != "" {
token = strings.Trim(token, "sk-")
}
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
maxTokens := operation_setting.GetMaxUserTokens()
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
if hasFuzzy {
count, err := CountUserTokens(userId)
if err != nil {
common.SysLog("failed to count user tokens: " + err.Error())
return nil, 0, errors.New("获取令牌数量失败")
}
if int(count) > maxTokens {
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
}
}
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
if keyword != "" {
keywordPattern, err := sanitizeLikePattern(keyword)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern)
}
if token != "" {
tokenPattern, err := sanitizeLikePattern(token)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern)
}
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT
err = baseQuery.Limit(maxTokens).Count(&total).Error
if err != nil {
common.SysError("failed to count search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
// 再分页查数据
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
if err != nil {
common.SysError("failed to search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
return tokens, total, nil
}
func ValidateUserToken(key string) (token *Token, err error) {