mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 00:27:02 +00:00
Merge commit from fork
fix: harden token search with pagination, rate limiting and input validation
This commit is contained in:
@@ -175,6 +175,10 @@ var (
|
||||
|
||||
DownloadRateLimitNum = 10
|
||||
DownloadRateLimitDuration int64 = 60
|
||||
|
||||
// Per-user search rate limit (applies after authentication, keyed by user ID)
|
||||
SearchRateLimitNum = 10
|
||||
SearchRateLimitDuration int64 = 60
|
||||
)
|
||||
|
||||
var RateLimitKeyExpirationDuration = 20 * time.Minute
|
||||
|
||||
@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
|
||||
case int:
|
||||
return fmt.Sprintf("%d", inter.(int))
|
||||
case float64:
|
||||
return fmt.Sprintf("%f", inter.(float64))
|
||||
return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
|
||||
case bool:
|
||||
if inter.(bool) {
|
||||
return "true"
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/i18n"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -31,16 +32,17 @@ func SearchTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
keyword := c.Query("keyword")
|
||||
token := c.Query("token")
|
||||
tokens, err := model.SearchUserTokens(userId, keyword, token)
|
||||
|
||||
pageInfo := common.GetPageQuery(c)
|
||||
|
||||
tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
})
|
||||
pageInfo.SetTotal(int(total))
|
||||
pageInfo.SetItems(tokens)
|
||||
common.ApiSuccess(c, pageInfo)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -157,6 +159,20 @@ func AddToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// 检查用户令牌数量是否已达上限
|
||||
maxTokens := operation_setting.GetMaxUserTokens()
|
||||
count, err := model.CountUserTokens(c.GetInt("id"))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
if int(count) >= maxTokens {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
|
||||
})
|
||||
return
|
||||
}
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
common.ApiErrorI18n(c, i18n.MsgTokenGenerateFailed)
|
||||
|
||||
@@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) {
|
||||
func UploadRateLimit() func(c *gin.Context) {
|
||||
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
|
||||
}
|
||||
|
||||
// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
|
||||
// instead of client IP, making it resistant to proxy rotation attacks.
|
||||
// Must be used AFTER authentication middleware (UserAuth).
|
||||
func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
|
||||
if common.RedisEnabled {
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
|
||||
userRedisRateLimiter(c, maxRequestNum, duration, key)
|
||||
}
|
||||
}
|
||||
// It's safe to call multi times.
|
||||
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
c.Status(http.StatusUnauthorized)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
key := fmt.Sprintf("%s:user:%d", mark, userId)
|
||||
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
|
||||
// (to support user-ID-based keys).
|
||||
func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
|
||||
ctx := context.Background()
|
||||
rdb := common.RDB
|
||||
listLength, err := rdb.LLen(ctx, key).Result()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if listLength < int64(maxRequestNum) {
|
||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
} else {
|
||||
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
||||
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
nowTimeStr := time.Now().Format(timeFormat)
|
||||
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
c.Status(http.StatusInternalServerError)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
c.Status(http.StatusTooManyRequests)
|
||||
c.Abort()
|
||||
return
|
||||
} else {
|
||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
||||
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SearchRateLimit returns a per-user rate limiter for search endpoints.
|
||||
// 10 requests per 60 seconds per user (by user ID, not IP).
|
||||
func SearchRateLimit() func(c *gin.Context) {
|
||||
return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
|
||||
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
|
||||
// 规则:
|
||||
// 1. 转义 _ 和 \(不允许 _ 作通配符)
|
||||
// 2. 连续的 % 合并为单个 %
|
||||
// 3. 最多允许 2 个 %
|
||||
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
|
||||
// 5. 不含 % 时按精确匹配
|
||||
func sanitizeLikePattern(input string) (string, error) {
|
||||
// 1. 转义 \ 和 _
|
||||
input = strings.ReplaceAll(input, `\`, `\\`)
|
||||
input = strings.ReplaceAll(input, `_`, `\_`)
|
||||
|
||||
// 2. 连续的 % 直接拒绝
|
||||
if strings.Contains(input, "%%") {
|
||||
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
|
||||
}
|
||||
|
||||
// 3. 统计 % 数量,不得超过 2
|
||||
count := strings.Count(input, "%")
|
||||
if count > 2 {
|
||||
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
|
||||
}
|
||||
|
||||
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
|
||||
if count > 0 {
|
||||
stripped := strings.ReplaceAll(input, "%", "")
|
||||
if len(stripped) < 2 {
|
||||
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
|
||||
}
|
||||
return input, nil
|
||||
}
|
||||
|
||||
// 5. 无 % 时,精确全匹配
|
||||
return input, nil
|
||||
}
|
||||
|
||||
const searchHardLimit = 100
|
||||
|
||||
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
|
||||
// model 层强制截断
|
||||
if limit <= 0 || limit > searchHardLimit {
|
||||
limit = searchHardLimit
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
token = strings.Trim(token, "sk-")
|
||||
}
|
||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
||||
return tokens, err
|
||||
|
||||
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
|
||||
maxTokens := operation_setting.GetMaxUserTokens()
|
||||
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
|
||||
if hasFuzzy {
|
||||
count, err := CountUserTokens(userId)
|
||||
if err != nil {
|
||||
common.SysLog("failed to count user tokens: " + err.Error())
|
||||
return nil, 0, errors.New("获取令牌数量失败")
|
||||
}
|
||||
if int(count) > maxTokens {
|
||||
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
|
||||
}
|
||||
}
|
||||
|
||||
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
|
||||
|
||||
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
|
||||
if keyword != "" {
|
||||
keywordPattern, err := sanitizeLikePattern(keyword)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern)
|
||||
}
|
||||
if token != "" {
|
||||
tokenPattern, err := sanitizeLikePattern(token)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern)
|
||||
}
|
||||
|
||||
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT)
|
||||
err = baseQuery.Limit(maxTokens).Count(&total).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to count search tokens: " + err.Error())
|
||||
return nil, 0, errors.New("搜索令牌失败")
|
||||
}
|
||||
|
||||
// 再分页查数据
|
||||
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to search tokens: " + err.Error())
|
||||
return nil, 0, errors.New("搜索令牌失败")
|
||||
}
|
||||
return tokens, total, nil
|
||||
}
|
||||
|
||||
func ValidateUserToken(key string) (token *Token, err error) {
|
||||
|
||||
@@ -238,7 +238,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
tokenRoute.Use(middleware.UserAuth())
|
||||
{
|
||||
tokenRoute.GET("/", controller.GetAllTokens)
|
||||
tokenRoute.GET("/search", controller.SearchTokens)
|
||||
tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
|
||||
tokenRoute.GET("/:id", controller.GetToken)
|
||||
tokenRoute.POST("/", controller.AddToken)
|
||||
tokenRoute.PUT("/", controller.UpdateToken)
|
||||
|
||||
@@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(strValue, 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
// 兼容 float 格式的字符串(如 "2.000000")
|
||||
floatValue, fErr := strconv.ParseFloat(strValue, 64)
|
||||
if fErr != nil {
|
||||
continue
|
||||
}
|
||||
intValue = int64(floatValue)
|
||||
}
|
||||
field.SetInt(intValue)
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
uintValue, err := strconv.ParseUint(strValue, 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
// 兼容 float 格式的字符串
|
||||
floatValue, fErr := strconv.ParseFloat(strValue, 64)
|
||||
if fErr != nil || floatValue < 0 {
|
||||
continue
|
||||
}
|
||||
uintValue = uint64(floatValue)
|
||||
}
|
||||
field.SetUint(uintValue)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
|
||||
28
setting/operation_setting/token_setting.go
Normal file
28
setting/operation_setting/token_setting.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package operation_setting
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/config"
|
||||
|
||||
// TokenSetting 令牌相关配置
|
||||
type TokenSetting struct {
|
||||
MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var tokenSetting = TokenSetting{
|
||||
MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("token_setting", &tokenSetting)
|
||||
}
|
||||
|
||||
// GetTokenSetting 获取令牌配置
|
||||
func GetTokenSetting() *TokenSetting {
|
||||
return &tokenSetting
|
||||
}
|
||||
|
||||
// GetMaxUserTokens 获取每用户最大令牌数量
|
||||
func GetMaxUserTokens() int {
|
||||
return GetTokenSetting().MaxUserTokens
|
||||
}
|
||||
@@ -78,6 +78,9 @@ const OperationSetting = () => {
|
||||
'checkin_setting.enabled': false,
|
||||
'checkin_setting.min_quota': 1000,
|
||||
'checkin_setting.max_quota': 10000,
|
||||
|
||||
/* 令牌设置 */
|
||||
'token_setting.max_user_tokens': 1000,
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
|
||||
@@ -40,6 +40,7 @@ export const useTokensData = (openFluentNotification) => {
|
||||
const [tokenCount, setTokenCount] = useState(0);
|
||||
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
|
||||
const [searching, setSearching] = useState(false);
|
||||
const [searchMode, setSearchMode] = useState(false); // 是否处于搜索结果视图
|
||||
|
||||
// Selection state
|
||||
const [selectedKeys, setSelectedKeys] = useState([]);
|
||||
@@ -91,6 +92,7 @@ export const useTokensData = (openFluentNotification) => {
|
||||
// Load tokens function
|
||||
const loadTokens = async (page = 1, size = pageSize) => {
|
||||
setLoading(true);
|
||||
setSearchMode(false);
|
||||
const res = await API.get(`/api/token/?p=${page}&size=${size}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
@@ -188,21 +190,21 @@ export const useTokensData = (openFluentNotification) => {
|
||||
};
|
||||
|
||||
// Search tokens function
|
||||
const searchTokens = async () => {
|
||||
const searchTokens = async (page = 1, size = pageSize) => {
|
||||
const { searchKeyword, searchToken } = getFormValues();
|
||||
if (searchKeyword === '' && searchToken === '') {
|
||||
setSearchMode(false);
|
||||
await loadTokens(1);
|
||||
return;
|
||||
}
|
||||
setSearching(true);
|
||||
const res = await API.get(
|
||||
`/api/token/search?keyword=${searchKeyword}&token=${searchToken}`,
|
||||
`/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setTokens(data);
|
||||
setTokenCount(data.length);
|
||||
setActivePage(1);
|
||||
setSearchMode(true);
|
||||
syncPageData(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -226,12 +228,20 @@ export const useTokensData = (openFluentNotification) => {
|
||||
|
||||
// Page handlers
|
||||
const handlePageChange = (page) => {
|
||||
loadTokens(page, pageSize).then();
|
||||
if (searchMode) {
|
||||
searchTokens(page, pageSize).then();
|
||||
} else {
|
||||
loadTokens(page, pageSize).then();
|
||||
}
|
||||
};
|
||||
|
||||
const handlePageSizeChange = async (size) => {
|
||||
setPageSize(size);
|
||||
await loadTokens(1, size);
|
||||
if (searchMode) {
|
||||
await searchTokens(1, size);
|
||||
} else {
|
||||
await loadTokens(1, size);
|
||||
}
|
||||
};
|
||||
|
||||
// Row selection handlers
|
||||
|
||||
@@ -56,6 +56,7 @@ export default function GeneralSettings(props) {
|
||||
DefaultCollapseSidebar: false,
|
||||
DemoSiteEnabled: false,
|
||||
SelfUseModeEnabled: false,
|
||||
'token_setting.max_user_tokens': 1000,
|
||||
});
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
@@ -287,6 +288,19 @@ export default function GeneralSettings(props) {
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row gutter={16}>
|
||||
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
|
||||
<Form.InputNumber
|
||||
label={t('用户最大令牌数量')}
|
||||
field={'token_setting.max_user_tokens'}
|
||||
step={1}
|
||||
min={1}
|
||||
extraText={t('每个用户最多可创建的令牌数量,默认 1000,设置过大可能会影响性能')}
|
||||
placeholder={'1000'}
|
||||
onChange={handleFieldChange('token_setting.max_user_tokens')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row>
|
||||
<Button size='default' onClick={onSubmit}>
|
||||
{t('保存通用设置')}
|
||||
|
||||
Reference in New Issue
Block a user