diff --git a/common/constants.go b/common/constants.go
index 51b798dbc..204f2e8cf 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -175,6 +175,10 @@ var (
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
+
+ // Per-user search rate limit (applies after authentication, keyed by user ID)
+ SearchRateLimitNum = 10
+ SearchRateLimitDuration int64 = 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
diff --git a/common/utils.go b/common/utils.go
index b67fe1c5f..3a8be45b3 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
- return fmt.Sprintf("%f", inter.(float64))
+ return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
case bool:
if inter.(bool) {
return "true"
diff --git a/controller/token.go b/controller/token.go
index c5dc5ec42..21f63665d 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -8,6 +8,7 @@ import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
+ "github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/gin-gonic/gin"
)
@@ -31,16 +32,17 @@ func SearchTokens(c *gin.Context) {
userId := c.GetInt("id")
keyword := c.Query("keyword")
token := c.Query("token")
- tokens, err := model.SearchUserTokens(userId, keyword, token)
+
+ pageInfo := common.GetPageQuery(c)
+
+ tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
if err != nil {
common.ApiError(c, err)
return
}
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": tokens,
- })
+ pageInfo.SetTotal(int(total))
+ pageInfo.SetItems(tokens)
+ common.ApiSuccess(c, pageInfo)
return
}
@@ -168,6 +170,20 @@ func AddToken(c *gin.Context) {
return
}
}
+ // 检查用户令牌数量是否已达上限
+ maxTokens := operation_setting.GetMaxUserTokens()
+ count, err := model.CountUserTokens(c.GetInt("id"))
+ if err != nil {
+ common.ApiError(c, err)
+ return
+ }
+ if int(count) >= maxTokens {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
+ })
+ return
+ }
key, err := common.GenerateKey()
if err != nil {
c.JSON(http.StatusOK, gin.H{
diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go
index 866542e17..10d7d8217 100644
--- a/middleware/rate-limit.go
+++ b/middleware/rate-limit.go
@@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) {
func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
}
+
+// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
+// instead of client IP, making it resistant to proxy rotation attacks.
+// Must be used AFTER authentication middleware (UserAuth).
+func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
+ if common.RedisEnabled {
+ return func(c *gin.Context) {
+ userId := c.GetInt("id")
+ if userId == 0 {
+ c.Status(http.StatusUnauthorized)
+ c.Abort()
+ return
+ }
+ key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
+ userRedisRateLimiter(c, maxRequestNum, duration, key)
+ }
+ }
+ // It's safe to call multi times.
+ inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+ return func(c *gin.Context) {
+ userId := c.GetInt("id")
+ if userId == 0 {
+ c.Status(http.StatusUnauthorized)
+ c.Abort()
+ return
+ }
+ key := fmt.Sprintf("%s:user:%d", mark, userId)
+ if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
+ c.Status(http.StatusTooManyRequests)
+ c.Abort()
+ return
+ }
+ }
+}
+
+// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
+// (to support user-ID-based keys).
+func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
+ ctx := context.Background()
+ rdb := common.RDB
+ listLength, err := rdb.LLen(ctx, key).Result()
+ if err != nil {
+ fmt.Println(err.Error())
+ c.Status(http.StatusInternalServerError)
+ c.Abort()
+ return
+ }
+ if listLength < int64(maxRequestNum) {
+ rdb.LPush(ctx, key, time.Now().Format(timeFormat))
+ rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ } else {
+ oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
+ oldTime, err := time.Parse(timeFormat, oldTimeStr)
+ if err != nil {
+ fmt.Println(err)
+ c.Status(http.StatusInternalServerError)
+ c.Abort()
+ return
+ }
+ nowTimeStr := time.Now().Format(timeFormat)
+ nowTime, err := time.Parse(timeFormat, nowTimeStr)
+ if err != nil {
+ fmt.Println(err)
+ c.Status(http.StatusInternalServerError)
+ c.Abort()
+ return
+ }
+ if int64(nowTime.Sub(oldTime).Seconds()) < duration {
+ rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ c.Status(http.StatusTooManyRequests)
+ c.Abort()
+ return
+ } else {
+ rdb.LPush(ctx, key, time.Now().Format(timeFormat))
+ rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
+ rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+ }
+ }
+}
+
+// SearchRateLimit returns a per-user rate limiter for search endpoints.
+// 10 requests per 60 seconds per user (by user ID, not IP).
+func SearchRateLimit() func(c *gin.Context) {
+ return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
+}
diff --git a/model/token.go b/model/token.go
index b68fc0cfb..ab3804b20 100644
--- a/model/token.go
+++ b/model/token.go
@@ -6,6 +6,7 @@ import (
"strings"
"github.com/QuantumNous/new-api/common"
+ "github.com/QuantumNous/new-api/setting/operation_setting"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
return tokens, err
}
-func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
+// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
+// 规则:
+// 1. 转义 _ 和 \(不允许 _ 作通配符)
+// 2. 连续的 % 合并为单个 %
+// 3. 最多允许 2 个 %
+// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
+// 5. 不含 % 时按精确匹配
+func sanitizeLikePattern(input string) (string, error) {
+ // 1. 转义 \ 和 _
+ input = strings.ReplaceAll(input, `\`, `\\`)
+ input = strings.ReplaceAll(input, `_`, `\_`)
+
+ // 2. 连续的 % 直接拒绝
+ if strings.Contains(input, "%%") {
+ return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
+ }
+
+ // 3. 统计 % 数量,不得超过 2
+ count := strings.Count(input, "%")
+ if count > 2 {
+ return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
+ }
+
+ // 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
+ if count > 0 {
+ stripped := strings.ReplaceAll(input, "%", "")
+ if len(stripped) < 2 {
+ return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
+ }
+ return input, nil
+ }
+
+ // 5. 无 % 时,精确全匹配
+ return input, nil
+}
+
+const searchHardLimit = 100
+
+func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
+ // model 层强制截断
+ if limit <= 0 || limit > searchHardLimit {
+ limit = searchHardLimit
+ }
+ if offset < 0 {
+ offset = 0
+ }
+
if token != "" {
token = strings.Trim(token, "sk-")
}
- err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
- return tokens, err
+
+ // 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
+ maxTokens := operation_setting.GetMaxUserTokens()
+ hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
+ if hasFuzzy {
+ count, err := CountUserTokens(userId)
+ if err != nil {
+ common.SysLog("failed to count user tokens: " + err.Error())
+ return nil, 0, errors.New("获取令牌数量失败")
+ }
+ if int(count) > maxTokens {
+ return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
+ }
+ }
+
+ baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
+
+ // 非空才加 LIKE 条件,空则跳过(不过滤该字段)
+ if keyword != "" {
+ keywordPattern, err := sanitizeLikePattern(keyword)
+ if err != nil {
+ return nil, 0, err
+ }
+ baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern)
+ }
+ if token != "" {
+ tokenPattern, err := sanitizeLikePattern(token)
+ if err != nil {
+ return nil, 0, err
+ }
+ baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern)
+ }
+
+ // 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT)
+ err = baseQuery.Limit(maxTokens).Count(&total).Error
+ if err != nil {
+ common.SysError("failed to count search tokens: " + err.Error())
+ return nil, 0, errors.New("搜索令牌失败")
+ }
+
+ // 再分页查数据
+ err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
+ if err != nil {
+ common.SysError("failed to search tokens: " + err.Error())
+ return nil, 0, errors.New("搜索令牌失败")
+ }
+ return tokens, total, nil
}
func ValidateUserToken(key string) (token *Token, err error) {
diff --git a/router/api-router.go b/router/api-router.go
index 973684958..bc926be50 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -186,7 +186,7 @@ func SetApiRouter(router *gin.Engine) {
tokenRoute.Use(middleware.UserAuth())
{
tokenRoute.GET("/", controller.GetAllTokens)
- tokenRoute.GET("/search", controller.SearchTokens)
+ tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
tokenRoute.GET("/:id", controller.GetToken)
tokenRoute.POST("/", controller.AddToken)
tokenRoute.PUT("/", controller.UpdateToken)
diff --git a/setting/config/config.go b/setting/config/config.go
index 6c6abe9d4..8b3d05139 100644
--- a/setting/config/config.go
+++ b/setting/config/config.go
@@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, err := strconv.ParseInt(strValue, 10, 64)
if err != nil {
- continue
+ // 兼容 float 格式的字符串(如 "2.000000")
+ floatValue, fErr := strconv.ParseFloat(strValue, 64)
+ if fErr != nil {
+ continue
+ }
+ intValue = int64(floatValue)
}
field.SetInt(intValue)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, err := strconv.ParseUint(strValue, 10, 64)
if err != nil {
- continue
+ // 兼容 float 格式的字符串
+ floatValue, fErr := strconv.ParseFloat(strValue, 64)
+ if fErr != nil || floatValue < 0 {
+ continue
+ }
+ uintValue = uint64(floatValue)
}
field.SetUint(uintValue)
case reflect.Float32, reflect.Float64:
diff --git a/setting/operation_setting/token_setting.go b/setting/operation_setting/token_setting.go
new file mode 100644
index 000000000..0d4c4e2f2
--- /dev/null
+++ b/setting/operation_setting/token_setting.go
@@ -0,0 +1,28 @@
+package operation_setting
+
+import "github.com/QuantumNous/new-api/setting/config"
+
+// TokenSetting 令牌相关配置
+type TokenSetting struct {
+ MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量
+}
+
+// 默认配置
+var tokenSetting = TokenSetting{
+ MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌
+}
+
+func init() {
+ // 注册到全局配置管理器
+ config.GlobalConfig.Register("token_setting", &tokenSetting)
+}
+
+// GetTokenSetting 获取令牌配置
+func GetTokenSetting() *TokenSetting {
+ return &tokenSetting
+}
+
+// GetMaxUserTokens 获取每用户最大令牌数量
+func GetMaxUserTokens() int {
+ return GetTokenSetting().MaxUserTokens
+}
diff --git a/web/src/components/settings/OperationSetting.jsx b/web/src/components/settings/OperationSetting.jsx
index 9ee5fd007..0a3479009 100644
--- a/web/src/components/settings/OperationSetting.jsx
+++ b/web/src/components/settings/OperationSetting.jsx
@@ -77,6 +77,9 @@ const OperationSetting = () => {
'checkin_setting.enabled': false,
'checkin_setting.min_quota': 1000,
'checkin_setting.max_quota': 10000,
+
+ /* 令牌设置 */
+ 'token_setting.max_user_tokens': 1000,
});
let [loading, setLoading] = useState(false);
diff --git a/web/src/hooks/tokens/useTokensData.jsx b/web/src/hooks/tokens/useTokensData.jsx
index a34508f49..35729c015 100644
--- a/web/src/hooks/tokens/useTokensData.jsx
+++ b/web/src/hooks/tokens/useTokensData.jsx
@@ -40,6 +40,7 @@ export const useTokensData = (openFluentNotification) => {
const [tokenCount, setTokenCount] = useState(0);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searching, setSearching] = useState(false);
+ const [searchMode, setSearchMode] = useState(false); // 是否处于搜索结果视图
// Selection state
const [selectedKeys, setSelectedKeys] = useState([]);
@@ -91,6 +92,7 @@ export const useTokensData = (openFluentNotification) => {
// Load tokens function
const loadTokens = async (page = 1, size = pageSize) => {
setLoading(true);
+ setSearchMode(false);
const res = await API.get(`/api/token/?p=${page}&size=${size}`);
const { success, message, data } = res.data;
if (success) {
@@ -188,21 +190,21 @@ export const useTokensData = (openFluentNotification) => {
};
// Search tokens function
- const searchTokens = async () => {
+ const searchTokens = async (page = 1, size = pageSize) => {
const { searchKeyword, searchToken } = getFormValues();
if (searchKeyword === '' && searchToken === '') {
+ setSearchMode(false);
await loadTokens(1);
return;
}
setSearching(true);
const res = await API.get(
- `/api/token/search?keyword=${searchKeyword}&token=${searchToken}`,
+ `/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`,
);
const { success, message, data } = res.data;
if (success) {
- setTokens(data);
- setTokenCount(data.length);
- setActivePage(1);
+ setSearchMode(true);
+ syncPageData(data);
} else {
showError(message);
}
@@ -226,12 +228,20 @@ export const useTokensData = (openFluentNotification) => {
// Page handlers
const handlePageChange = (page) => {
- loadTokens(page, pageSize).then();
+ if (searchMode) {
+ searchTokens(page, pageSize).then();
+ } else {
+ loadTokens(page, pageSize).then();
+ }
};
const handlePageSizeChange = async (size) => {
setPageSize(size);
- await loadTokens(1, size);
+ if (searchMode) {
+ await searchTokens(1, size);
+ } else {
+ await loadTokens(1, size);
+ }
};
// Row selection handlers
diff --git a/web/src/pages/Setting/Operation/SettingsGeneral.jsx b/web/src/pages/Setting/Operation/SettingsGeneral.jsx
index fbfa0ed99..8b9a621da 100644
--- a/web/src/pages/Setting/Operation/SettingsGeneral.jsx
+++ b/web/src/pages/Setting/Operation/SettingsGeneral.jsx
@@ -56,6 +56,7 @@ export default function GeneralSettings(props) {
DefaultCollapseSidebar: false,
DemoSiteEnabled: false,
SelfUseModeEnabled: false,
+ 'token_setting.max_user_tokens': 1000,
});
const refForm = useRef();
const [inputsRow, setInputsRow] = useState(inputs);
@@ -287,6 +288,19 @@ export default function GeneralSettings(props) {
/>
+
+
+
+
+