From 3e1be18310f35d20742683ca9e4bf3bcafc173c5 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 6 Feb 2026 17:47:34 +0800 Subject: [PATCH] fix: harden token search with pagination, rate limiting and input validation - Add configurable per-user token creation limit (max_user_tokens) - Sanitize search input patterns to prevent expensive queries - Add per-user search rate limiting (by user ID) - Add pagination to search endpoint with strict page size cap - Skip empty search fields instead of matching nothing - Hide internal errors from API responses - Fix Interface2String float64 formatting causing config parse failures - Add float-string fallback in config system for int/uint fields --- common/constants.go | 4 + common/utils.go | 2 +- controller/token.go | 28 ++++-- middleware/rate-limit.go | 85 ++++++++++++++++ model/token.go | 98 ++++++++++++++++++- router/api-router.go | 2 +- setting/config/config.go | 14 ++- setting/operation_setting/token_setting.go | 28 ++++++ .../components/settings/OperationSetting.jsx | 3 + web/src/hooks/tokens/useTokensData.jsx | 24 +++-- .../Setting/Operation/SettingsGeneral.jsx | 14 +++ 11 files changed, 282 insertions(+), 20 deletions(-) create mode 100644 setting/operation_setting/token_setting.go diff --git a/common/constants.go b/common/constants.go index 51b798dbc..204f2e8cf 100644 --- a/common/constants.go +++ b/common/constants.go @@ -175,6 +175,10 @@ var ( DownloadRateLimitNum = 10 DownloadRateLimitDuration int64 = 60 + + // Per-user search rate limit (applies after authentication, keyed by user ID) + SearchRateLimitNum = 10 + SearchRateLimitDuration int64 = 60 ) var RateLimitKeyExpirationDuration = 20 * time.Minute diff --git a/common/utils.go b/common/utils.go index b67fe1c5f..3a8be45b3 100644 --- a/common/utils.go +++ b/common/utils.go @@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string { case int: return fmt.Sprintf("%d", inter.(int)) case float64: - return fmt.Sprintf("%f", inter.(float64)) + return strconv.FormatFloat(inter.(float64), 'f', -1, 64) case bool: if inter.(bool) { return "true" diff --git a/controller/token.go b/controller/token.go index c5dc5ec42..21f63665d 100644 --- a/controller/token.go +++ b/controller/token.go @@ -8,6 +8,7 @@ import ( "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/gin-gonic/gin" ) @@ -31,16 +32,17 @@ func SearchTokens(c *gin.Context) { userId := c.GetInt("id") keyword := c.Query("keyword") token := c.Query("token") - tokens, err := model.SearchUserTokens(userId, keyword, token) + + pageInfo := common.GetPageQuery(c) + + tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) if err != nil { common.ApiError(c, err) return } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": tokens, - }) + pageInfo.SetTotal(int(total)) + pageInfo.SetItems(tokens) + common.ApiSuccess(c, pageInfo) return } @@ -168,6 +170,20 @@ func AddToken(c *gin.Context) { return } } + // 检查用户令牌数量是否已达上限 + maxTokens := operation_setting.GetMaxUserTokens() + count, err := model.CountUserTokens(c.GetInt("id")) + if err != nil { + common.ApiError(c, err) + return + } + if int(count) >= maxTokens { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens), + }) + return + } key, err := common.GenerateKey() if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 866542e17..10d7d8217 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) { func UploadRateLimit() func(c *gin.Context) { return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") } + +// userRateLimitFactory creates a rate limiter keyed by authenticated user ID +// instead of client IP, making it resistant to proxy rotation attacks. +// Must be used AFTER authentication middleware (UserAuth). +func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { + if common.RedisEnabled { + return func(c *gin.Context) { + userId := c.GetInt("id") + if userId == 0 { + c.Status(http.StatusUnauthorized) + c.Abort() + return + } + key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId) + userRedisRateLimiter(c, maxRequestNum, duration, key) + } + } + // It's safe to call multi times. + inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + return func(c *gin.Context) { + userId := c.GetInt("id") + if userId == 0 { + c.Status(http.StatusUnauthorized) + c.Abort() + return + } + key := fmt.Sprintf("%s:user:%d", mark, userId) + if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } + } +} + +// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key +// (to support user-ID-based keys). +func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) { + ctx := context.Background() + rdb := common.RDB + listLength, err := rdb.LLen(ctx, key).Result() + if err != nil { + fmt.Println(err.Error()) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + if listLength < int64(maxRequestNum) { + rdb.LPush(ctx, key, time.Now().Format(timeFormat)) + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + } else { + oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() + oldTime, err := time.Parse(timeFormat, oldTimeStr) + if err != nil { + fmt.Println(err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + nowTimeStr := time.Now().Format(timeFormat) + nowTime, err := time.Parse(timeFormat, nowTimeStr) + if err != nil { + fmt.Println(err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + if int64(nowTime.Sub(oldTime).Seconds()) < duration { + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + c.Status(http.StatusTooManyRequests) + c.Abort() + return + } else { + rdb.LPush(ctx, key, time.Now().Format(timeFormat)) + rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) + rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + } + } +} + +// SearchRateLimit returns a per-user rate limiter for search endpoints. +// 10 requests per 60 seconds per user (by user ID, not IP). +func SearchRateLimit() func(c *gin.Context) { + return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR") +} diff --git a/model/token.go b/model/token.go index b68fc0cfb..ab3804b20 100644 --- a/model/token.go +++ b/model/token.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/setting/operation_setting" "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" ) @@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { return tokens, err } -func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) { +// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。 +// 规则: +// 1. 转义 _ 和 \(不允许 _ 作通配符) +// 2. 连续的 % 合并为单个 % +// 3. 最多允许 2 个 % +// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2 +// 5. 不含 % 时按精确匹配 +func sanitizeLikePattern(input string) (string, error) { + // 1. 转义 \ 和 _ + input = strings.ReplaceAll(input, `\`, `\\`) + input = strings.ReplaceAll(input, `_`, `\_`) + + // 2. 连续的 % 直接拒绝 + if strings.Contains(input, "%%") { + return "", errors.New("搜索模式中不允许包含连续的 % 通配符") + } + + // 3. 统计 % 数量,不得超过 2 + count := strings.Count(input, "%") + if count > 2 { + return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符") + } + + // 4. 含 % 时,去掉 % 后关键词长度必须 >= 2 + if count > 0 { + stripped := strings.ReplaceAll(input, "%", "") + if len(stripped) < 2 { + return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符") + } + return input, nil + } + + // 5. 无 % 时,精确全匹配 + return input, nil +} + +const searchHardLimit = 100 + +func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) { + // model 层强制截断 + if limit <= 0 || limit > searchHardLimit { + limit = searchHardLimit + } + if offset < 0 { + offset = 0 + } + if token != "" { token = strings.Trim(token, "sk-") } - err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error - return tokens, err + + // 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索 + maxTokens := operation_setting.GetMaxUserTokens() + hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%") + if hasFuzzy { + count, err := CountUserTokens(userId) + if err != nil { + common.SysLog("failed to count user tokens: " + err.Error()) + return nil, 0, errors.New("获取令牌数量失败") + } + if int(count) > maxTokens { + return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符") + } + } + + baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId) + + // 非空才加 LIKE 条件,空则跳过(不过滤该字段) + if keyword != "" { + keywordPattern, err := sanitizeLikePattern(keyword) + if err != nil { + return nil, 0, err + } + baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern) + } + if token != "" { + tokenPattern, err := sanitizeLikePattern(token) + if err != nil { + return nil, 0, err + } + baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern) + } + + // 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT) + err = baseQuery.Limit(maxTokens).Count(&total).Error + if err != nil { + common.SysError("failed to count search tokens: " + err.Error()) + return nil, 0, errors.New("搜索令牌失败") + } + + // 再分页查数据 + err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error + if err != nil { + common.SysError("failed to search tokens: " + err.Error()) + return nil, 0, errors.New("搜索令牌失败") + } + return tokens, total, nil } func ValidateUserToken(key string) (token *Token, err error) { diff --git a/router/api-router.go b/router/api-router.go index 973684958..bc926be50 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -186,7 +186,7 @@ func SetApiRouter(router *gin.Engine) { tokenRoute.Use(middleware.UserAuth()) { tokenRoute.GET("/", controller.GetAllTokens) - tokenRoute.GET("/search", controller.SearchTokens) + tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens) tokenRoute.GET("/:id", controller.GetToken) tokenRoute.POST("/", controller.AddToken) tokenRoute.PUT("/", controller.UpdateToken) diff --git a/setting/config/config.go b/setting/config/config.go index 6c6abe9d4..8b3d05139 100644 --- a/setting/config/config.go +++ b/setting/config/config.go @@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: intValue, err := strconv.ParseInt(strValue, 10, 64) if err != nil { - continue + // 兼容 float 格式的字符串(如 "2.000000") + floatValue, fErr := strconv.ParseFloat(strValue, 64) + if fErr != nil { + continue + } + intValue = int64(floatValue) } field.SetInt(intValue) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: uintValue, err := strconv.ParseUint(strValue, 10, 64) if err != nil { - continue + // 兼容 float 格式的字符串 + floatValue, fErr := strconv.ParseFloat(strValue, 64) + if fErr != nil || floatValue < 0 { + continue + } + uintValue = uint64(floatValue) } field.SetUint(uintValue) case reflect.Float32, reflect.Float64: diff --git a/setting/operation_setting/token_setting.go b/setting/operation_setting/token_setting.go new file mode 100644 index 000000000..0d4c4e2f2 --- /dev/null +++ b/setting/operation_setting/token_setting.go @@ -0,0 +1,28 @@ +package operation_setting + +import "github.com/QuantumNous/new-api/setting/config" + +// TokenSetting 令牌相关配置 +type TokenSetting struct { + MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量 +} + +// 默认配置 +var tokenSetting = TokenSetting{ + MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌 +} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("token_setting", &tokenSetting) +} + +// GetTokenSetting 获取令牌配置 +func GetTokenSetting() *TokenSetting { + return &tokenSetting +} + +// GetMaxUserTokens 获取每用户最大令牌数量 +func GetMaxUserTokens() int { + return GetTokenSetting().MaxUserTokens +} diff --git a/web/src/components/settings/OperationSetting.jsx b/web/src/components/settings/OperationSetting.jsx index 9ee5fd007..0a3479009 100644 --- a/web/src/components/settings/OperationSetting.jsx +++ b/web/src/components/settings/OperationSetting.jsx @@ -77,6 +77,9 @@ const OperationSetting = () => { 'checkin_setting.enabled': false, 'checkin_setting.min_quota': 1000, 'checkin_setting.max_quota': 10000, + + /* 令牌设置 */ + 'token_setting.max_user_tokens': 1000, }); let [loading, setLoading] = useState(false); diff --git a/web/src/hooks/tokens/useTokensData.jsx b/web/src/hooks/tokens/useTokensData.jsx index a34508f49..35729c015 100644 --- a/web/src/hooks/tokens/useTokensData.jsx +++ b/web/src/hooks/tokens/useTokensData.jsx @@ -40,6 +40,7 @@ export const useTokensData = (openFluentNotification) => { const [tokenCount, setTokenCount] = useState(0); const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); const [searching, setSearching] = useState(false); + const [searchMode, setSearchMode] = useState(false); // 是否处于搜索结果视图 // Selection state const [selectedKeys, setSelectedKeys] = useState([]); @@ -91,6 +92,7 @@ export const useTokensData = (openFluentNotification) => { // Load tokens function const loadTokens = async (page = 1, size = pageSize) => { setLoading(true); + setSearchMode(false); const res = await API.get(`/api/token/?p=${page}&size=${size}`); const { success, message, data } = res.data; if (success) { @@ -188,21 +190,21 @@ export const useTokensData = (openFluentNotification) => { }; // Search tokens function - const searchTokens = async () => { + const searchTokens = async (page = 1, size = pageSize) => { const { searchKeyword, searchToken } = getFormValues(); if (searchKeyword === '' && searchToken === '') { + setSearchMode(false); await loadTokens(1); return; } setSearching(true); const res = await API.get( - `/api/token/search?keyword=${searchKeyword}&token=${searchToken}`, + `/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`, ); const { success, message, data } = res.data; if (success) { - setTokens(data); - setTokenCount(data.length); - setActivePage(1); + setSearchMode(true); + syncPageData(data); } else { showError(message); } @@ -226,12 +228,20 @@ export const useTokensData = (openFluentNotification) => { // Page handlers const handlePageChange = (page) => { - loadTokens(page, pageSize).then(); + if (searchMode) { + searchTokens(page, pageSize).then(); + } else { + loadTokens(page, pageSize).then(); + } }; const handlePageSizeChange = async (size) => { setPageSize(size); - await loadTokens(1, size); + if (searchMode) { + await searchTokens(1, size); + } else { + await loadTokens(1, size); + } }; // Row selection handlers diff --git a/web/src/pages/Setting/Operation/SettingsGeneral.jsx b/web/src/pages/Setting/Operation/SettingsGeneral.jsx index fbfa0ed99..8b9a621da 100644 --- a/web/src/pages/Setting/Operation/SettingsGeneral.jsx +++ b/web/src/pages/Setting/Operation/SettingsGeneral.jsx @@ -56,6 +56,7 @@ export default function GeneralSettings(props) { DefaultCollapseSidebar: false, DemoSiteEnabled: false, SelfUseModeEnabled: false, + 'token_setting.max_user_tokens': 1000, }); const refForm = useRef(); const [inputsRow, setInputsRow] = useState(inputs); @@ -287,6 +288,19 @@ export default function GeneralSettings(props) { /> + + + + +