Merge commit from fork

fix: harden token search with pagination, rate limiting and input validation
This commit is contained in:
Calcium-Ion
2026-02-06 17:54:40 +08:00
committed by GitHub
11 changed files with 282 additions and 20 deletions

View File

@@ -175,6 +175,10 @@ var (
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
// Per-user search rate limit (applies after authentication, keyed by user ID)
SearchRateLimitNum = 10
SearchRateLimitDuration int64 = 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute

View File

@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
case bool:
if inter.(bool) {
return "true"

View File

@@ -8,6 +8,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/i18n"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
)
@@ -31,16 +32,17 @@ func SearchTokens(c *gin.Context) {
userId := c.GetInt("id")
keyword := c.Query("keyword")
token := c.Query("token")
tokens, err := model.SearchUserTokens(userId, keyword, token)
pageInfo := common.GetPageQuery(c)
tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": tokens,
})
pageInfo.SetTotal(int(total))
pageInfo.SetItems(tokens)
common.ApiSuccess(c, pageInfo)
return
}
@@ -157,6 +159,20 @@ func AddToken(c *gin.Context) {
return
}
}
// 检查用户令牌数量是否已达上限
maxTokens := operation_setting.GetMaxUserTokens()
count, err := model.CountUserTokens(c.GetInt("id"))
if err != nil {
common.ApiError(c, err)
return
}
if int(count) >= maxTokens {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
})
return
}
key, err := common.GenerateKey()
if err != nil {
common.ApiErrorI18n(c, i18n.MsgTokenGenerateFailed)

View File

@@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) {
func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
}
// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
// instead of client IP, making it resistant to proxy rotation attacks.
// Must be used AFTER authentication middleware (UserAuth).
func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
if common.RedisEnabled {
return func(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
userRedisRateLimiter(c, maxRequestNum, duration, key)
}
}
// It's safe to call multi times.
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.Status(http.StatusUnauthorized)
c.Abort()
return
}
key := fmt.Sprintf("%s:user:%d", mark, userId)
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
c.Status(http.StatusTooManyRequests)
c.Abort()
return
}
}
}
// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
// (to support user-ID-based keys).
func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
ctx := context.Background()
rdb := common.RDB
listLength, err := rdb.LLen(ctx, key).Result()
if err != nil {
fmt.Println(err.Error())
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
} else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr)
if err != nil {
fmt.Println(err)
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
nowTimeStr := time.Now().Format(timeFormat)
nowTime, err := time.Parse(timeFormat, nowTimeStr)
if err != nil {
fmt.Println(err)
c.Status(http.StatusInternalServerError)
c.Abort()
return
}
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests)
c.Abort()
return
} else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
}
}
}
// SearchRateLimit returns a per-user rate limiter for search endpoints.
// 10 requests per 60 seconds per user (by user ID, not IP).
func SearchRateLimit() func(c *gin.Context) {
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
return tokens, err
}
func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
// 规则:
// 1. 转义 _ 和 \(不允许 _ 作通配符)
// 2. 连续的 % 合并为单个 %
// 3. 最多允许 2 个 %
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
// 5. 不含 % 时按精确匹配
func sanitizeLikePattern(input string) (string, error) {
// 1. 转义 \ 和 _
input = strings.ReplaceAll(input, `\`, `\\`)
input = strings.ReplaceAll(input, `_`, `\_`)
// 2. 连续的 % 直接拒绝
if strings.Contains(input, "%%") {
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
}
// 3. 统计 % 数量,不得超过 2
count := strings.Count(input, "%")
if count > 2 {
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
}
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
if count > 0 {
stripped := strings.ReplaceAll(input, "%", "")
if len(stripped) < 2 {
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
}
return input, nil
}
// 5. 无 % 时,精确全匹配
return input, nil
}
const searchHardLimit = 100
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
// model 层强制截断
if limit <= 0 || limit > searchHardLimit {
limit = searchHardLimit
}
if offset < 0 {
offset = 0
}
if token != "" {
token = strings.Trim(token, "sk-")
}
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
maxTokens := operation_setting.GetMaxUserTokens()
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
if hasFuzzy {
count, err := CountUserTokens(userId)
if err != nil {
common.SysLog("failed to count user tokens: " + err.Error())
return nil, 0, errors.New("获取令牌数量失败")
}
if int(count) > maxTokens {
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
}
}
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
if keyword != "" {
keywordPattern, err := sanitizeLikePattern(keyword)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern)
}
if token != "" {
tokenPattern, err := sanitizeLikePattern(token)
if err != nil {
return nil, 0, err
}
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern)
}
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT
err = baseQuery.Limit(maxTokens).Count(&total).Error
if err != nil {
common.SysError("failed to count search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
// 再分页查数据
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
if err != nil {
common.SysError("failed to search tokens: " + err.Error())
return nil, 0, errors.New("搜索令牌失败")
}
return tokens, total, nil
}
func ValidateUserToken(key string) (token *Token, err error) {

View File

@@ -238,7 +238,7 @@ func SetApiRouter(router *gin.Engine) {
tokenRoute.Use(middleware.UserAuth())
{
tokenRoute.GET("/", controller.GetAllTokens)
tokenRoute.GET("/search", controller.SearchTokens)
tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
tokenRoute.GET("/:id", controller.GetToken)
tokenRoute.POST("/", controller.AddToken)
tokenRoute.PUT("/", controller.UpdateToken)

View File

@@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(strValue, 10, 64)
if err != nil {
continue
// 兼容 float 格式的字符串(如 "2.000000"
floatValue, fErr := strconv.ParseFloat(strValue, 64)
if fErr != nil {
continue
}
intValue = int64(floatValue)
}
field.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(strValue, 10, 64)
if err != nil {
continue
// 兼容 float 格式的字符串
floatValue, fErr := strconv.ParseFloat(strValue, 64)
if fErr != nil || floatValue < 0 {
continue
}
uintValue = uint64(floatValue)
}
field.SetUint(uintValue)
case reflect.Float32, reflect.Float64:

View File

@@ -0,0 +1,28 @@
package operation_setting
import "github.com/QuantumNous/new-api/setting/config"
// TokenSetting 令牌相关配置
type TokenSetting struct {
MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量
}
// 默认配置
var tokenSetting = TokenSetting{
MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌
}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("token_setting", &tokenSetting)
}
// GetTokenSetting 获取令牌配置
func GetTokenSetting() *TokenSetting {
return &tokenSetting
}
// GetMaxUserTokens 获取每用户最大令牌数量
func GetMaxUserTokens() int {
return GetTokenSetting().MaxUserTokens
}

View File

@@ -78,6 +78,9 @@ const OperationSetting = () => {
'checkin_setting.enabled': false,
'checkin_setting.min_quota': 1000,
'checkin_setting.max_quota': 10000,
/* 令牌设置 */
'token_setting.max_user_tokens': 1000,
});
let [loading, setLoading] = useState(false);

View File

@@ -40,6 +40,7 @@ export const useTokensData = (openFluentNotification) => {
const [tokenCount, setTokenCount] = useState(0);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searching, setSearching] = useState(false);
const [searchMode, setSearchMode] = useState(false); // 是否处于搜索结果视图
// Selection state
const [selectedKeys, setSelectedKeys] = useState([]);
@@ -91,6 +92,7 @@ export const useTokensData = (openFluentNotification) => {
// Load tokens function
const loadTokens = async (page = 1, size = pageSize) => {
setLoading(true);
setSearchMode(false);
const res = await API.get(`/api/token/?p=${page}&size=${size}`);
const { success, message, data } = res.data;
if (success) {
@@ -188,21 +190,21 @@ export const useTokensData = (openFluentNotification) => {
};
// Search tokens function
const searchTokens = async () => {
const searchTokens = async (page = 1, size = pageSize) => {
const { searchKeyword, searchToken } = getFormValues();
if (searchKeyword === '' && searchToken === '') {
setSearchMode(false);
await loadTokens(1);
return;
}
setSearching(true);
const res = await API.get(
`/api/token/search?keyword=${searchKeyword}&token=${searchToken}`,
`/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`,
);
const { success, message, data } = res.data;
if (success) {
setTokens(data);
setTokenCount(data.length);
setActivePage(1);
setSearchMode(true);
syncPageData(data);
} else {
showError(message);
}
@@ -226,12 +228,20 @@ export const useTokensData = (openFluentNotification) => {
// Page handlers
const handlePageChange = (page) => {
loadTokens(page, pageSize).then();
if (searchMode) {
searchTokens(page, pageSize).then();
} else {
loadTokens(page, pageSize).then();
}
};
const handlePageSizeChange = async (size) => {
setPageSize(size);
await loadTokens(1, size);
if (searchMode) {
await searchTokens(1, size);
} else {
await loadTokens(1, size);
}
};
// Row selection handlers

View File

@@ -56,6 +56,7 @@ export default function GeneralSettings(props) {
DefaultCollapseSidebar: false,
DemoSiteEnabled: false,
SelfUseModeEnabled: false,
'token_setting.max_user_tokens': 1000,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -287,6 +288,19 @@ export default function GeneralSettings(props) {
/>
</Col>
</Row>
<Row gutter={16}>
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
<Form.InputNumber
label={t('用户最大令牌数量')}
field={'token_setting.max_user_tokens'}
step={1}
min={1}
extraText={t('每个用户最多可创建的令牌数量,默认 1000设置过大可能会影响性能')}
placeholder={'1000'}
onChange={handleFieldChange('token_setting.max_user_tokens')}
/>
</Col>
</Row>
<Row>
<Button size='default' onClick={onSubmit}>
{t('保存通用设置')}