From d814d62e2f33b0d9b8b0c4c44fa73b3de2da4de1 Mon Sep 17 00:00:00 2001 From: CaIon Date: Fri, 6 Feb 2026 21:26:26 +0800 Subject: [PATCH] refactor: enhance API security with read-only token authentication and improved rate limiting --- common/constants.go | 2 +- controller/log.go | 50 +++++++++--------- controller/secure_verification.go | 88 ------------------------------- middleware/auth.go | 57 ++++++++++++++++++++ model/log.go | 64 +++++++++++----------- router/api-router.go | 11 ++-- 6 files changed, 119 insertions(+), 153 deletions(-) diff --git a/common/constants.go b/common/constants.go index 204f2e8cf..6823b2c81 100644 --- a/common/constants.go +++ b/common/constants.go @@ -39,7 +39,7 @@ var OptionMap map[string]string var OptionMapRWMutex sync.RWMutex var ItemsPerPage = 10 -var MaxRecentItems = 100 +var MaxRecentItems = 1000 var PasswordLoginEnabled = true var PasswordRegisterEnabled = true diff --git a/controller/log.go b/controller/log.go index 1b2068b6c..cf3825f16 100644 --- a/controller/log.go +++ b/controller/log.go @@ -53,40 +53,32 @@ func GetUserLogs(c *gin.Context) { return } +// Deprecated: SearchAllLogs 已废弃,前端未使用该接口。 func SearchAllLogs(c *gin.Context) { - keyword := c.Query("keyword") - logs, err := model.SearchAllLogs(keyword) - if err != nil { - common.ApiError(c, err) - return - } c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": logs, + "success": false, + "message": "该接口已废弃", }) - return } +// Deprecated: SearchUserLogs 已废弃,前端未使用该接口。 func SearchUserLogs(c *gin.Context) { - keyword := c.Query("keyword") - userId := c.GetInt("id") - logs, err := model.SearchUserLogs(userId, keyword) - if err != nil { - common.ApiError(c, err) - return - } c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": logs, + "success": false, + "message": "该接口已废弃", }) - return } func GetLogByKey(c *gin.Context) { - key := c.Query("key") - logs, err := model.GetLogByKey(key) + tokenId := c.GetInt("token_id") + if tokenId == 0 { + c.JSON(200, gin.H{ + "success": false, + "message": "无效的令牌", + }) + return + } + logs, err := model.GetLogByTokenId(tokenId) if err != nil { c.JSON(200, gin.H{ "success": false, @@ -110,7 +102,11 @@ func GetLogsStat(c *gin.Context) { modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) group := c.Query("group") - stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) + stat, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) + if err != nil { + common.ApiError(c, err) + return + } //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") c.JSON(http.StatusOK, gin.H{ "success": true, @@ -133,7 +129,11 @@ func GetLogsSelfStat(c *gin.Context) { modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) group := c.Query("group") - quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) + quotaNum, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) + if err != nil { + common.ApiError(c, err) + return + } //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) c.JSON(200, gin.H{ "success": true, diff --git a/controller/secure_verification.go b/controller/secure_verification.go index f30c259e6..ad1a615ea 100644 --- a/controller/secure_verification.go +++ b/controller/secure_verification.go @@ -133,94 +133,6 @@ func UniversalVerify(c *gin.Context) { }) } -// GetVerificationStatus 获取验证状态 -func GetVerificationStatus(c *gin.Context) { - userId := c.GetInt("id") - if userId == 0 { - c.JSON(http.StatusUnauthorized, gin.H{ - "success": false, - "message": "未登录", - }) - return - } - - session := sessions.Default(c) - verifiedAtRaw := session.Get(SecureVerificationSessionKey) - - if verifiedAtRaw == nil { - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": VerificationStatusResponse{ - Verified: false, - }, - }) - return - } - - verifiedAt, ok := verifiedAtRaw.(int64) - if !ok { - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": VerificationStatusResponse{ - Verified: false, - }, - }) - return - } - - elapsed := time.Now().Unix() - verifiedAt - if elapsed >= SecureVerificationTimeout { - // 验证已过期 - session.Delete(SecureVerificationSessionKey) - _ = session.Save() - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": VerificationStatusResponse{ - Verified: false, - }, - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": VerificationStatusResponse{ - Verified: true, - ExpiresAt: verifiedAt + SecureVerificationTimeout, - }, - }) -} - -// CheckSecureVerification 检查是否已通过安全验证 -// 返回 true 表示验证有效,false 表示需要重新验证 -func CheckSecureVerification(c *gin.Context) bool { - session := sessions.Default(c) - verifiedAtRaw := session.Get(SecureVerificationSessionKey) - - if verifiedAtRaw == nil { - return false - } - - verifiedAt, ok := verifiedAtRaw.(int64) - if !ok { - return false - } - - elapsed := time.Now().Unix() - verifiedAt - if elapsed >= SecureVerificationTimeout { - // 验证已过期,清除 session - session.Delete(SecureVerificationSessionKey) - _ = session.Save() - return false - } - - return true -} - // PasskeyVerifyAndSetSession Passkey 验证完成后设置 session // 这是一个辅助函数,供 PasskeyVerifyFinish 调用 func PasskeyVerifyAndSetSession(c *gin.Context) { diff --git a/middleware/auth.go b/middleware/auth.go index 0bb27ead0..f5a8630ff 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -168,6 +168,63 @@ func WssAuth(c *gin.Context) { } +// TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。 +// 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。 +// 即使令牌已过期、已耗尽或已禁用,也允许访问。 +// 仍然检查用户是否被封禁。 +func TokenAuthReadOnly() func(c *gin.Context) { + return func(c *gin.Context) { + key := c.Request.Header.Get("Authorization") + if key == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "message": "未提供 Authorization 请求头", + }) + c.Abort() + return + } + if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") { + key = strings.TrimSpace(key[7:]) + } + key = strings.TrimPrefix(key, "sk-") + parts := strings.Split(key, "-") + key = parts[0] + + token, err := model.GetTokenByKey(key, false) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "message": "无效的令牌", + }) + c.Abort() + return + } + + userCache, err := model.GetUserCache(token.UserId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "success": false, + "message": err.Error(), + }) + c.Abort() + return + } + if userCache.Status != common.UserStatusEnabled { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "用户已被封禁", + }) + c.Abort() + return + } + + c.Set("id", token.UserId) + c.Set("token_id", token.Id) + c.Set("token_key", token.Key) + c.Next() + } +} + func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { // 先检测是否为ws diff --git a/model/log.go b/model/log.go index de6628e7f..664c180be 100644 --- a/model/log.go +++ b/model/log.go @@ -2,9 +2,8 @@ package model import ( "context" + "errors" "fmt" - "os" - "strings" "time" "github.com/QuantumNous/new-api/common" @@ -66,16 +65,8 @@ func formatUserLogs(logs []*Log) { } } -func GetLogByKey(key string) (logs []*Log, err error) { - if os.Getenv("LOG_SQL_DSN") != "" { - var tk Token - if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil { - return nil, err - } - err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error - } else { - err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error - } +func GetLogByTokenId(tokenId int) (logs []*Log, err error) { + err = LOG_DB.Model(&Log{}).Where("token_id = ?", tokenId).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error formatUserLogs(logs) return logs, err } @@ -276,6 +267,8 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName return logs, total, err } +const logSearchCountLimit = 10000 + func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) { var tx *gorm.DB if logType == LogTypeUnknown { @@ -285,7 +278,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } if modelName != "" { - tx = tx.Where("logs.model_name like ?", modelName) + modelNamePattern, err := sanitizeLikePattern(modelName) + if err != nil { + return nil, 0, err + } + tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern) } if tokenName != "" { tx = tx.Where("logs.token_name = ?", tokenName) @@ -302,37 +299,28 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int if group != "" { tx = tx.Where("logs."+logGroupCol+" = ?", group) } - err = tx.Model(&Log{}).Count(&total).Error + err = tx.Model(&Log{}).Limit(logSearchCountLimit).Count(&total).Error if err != nil { - return nil, 0, err + common.SysError("failed to count user logs: " + err.Error()) + return nil, 0, errors.New("查询日志失败") } err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error if err != nil { - return nil, 0, err + common.SysError("failed to search user logs: " + err.Error()) + return nil, 0, errors.New("查询日志失败") } formatUserLogs(logs) return logs, total, err } -func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error - return logs, err -} - -func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error - formatUserLogs(logs) - return logs, err -} - type Stat struct { Quota int `json:"quota"` Rpm int `json:"rpm"` Tpm int `json:"tpm"` } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) { +func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) { tx := LOG_DB.Table("logs").Select("sum(quota) quota") // 为rpm和tpm创建单独的查询 @@ -353,8 +341,12 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa tx = tx.Where("created_at <= ?", endTimestamp) } if modelName != "" { - tx = tx.Where("model_name like ?", modelName) - rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName) + modelNamePattern, err := sanitizeLikePattern(modelName) + if err != nil { + return stat, err + } + tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern) + rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern) } if channel != 0 { tx = tx.Where("channel_id = ?", channel) @@ -372,10 +364,16 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix()) // 执行查询 - tx.Scan(&stat) - rpmTpmQuery.Scan(&stat) + if err := tx.Scan(&stat).Error; err != nil { + common.SysError("failed to query log stat: " + err.Error()) + return stat, errors.New("查询统计数据失败") + } + if err := rpmTpmQuery.Scan(&stat).Error; err != nil { + common.SysError("failed to query rpm/tpm stat: " + err.Error()) + return stat, errors.New("查询统计数据失败") + } - return stat + return stat, nil } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { diff --git a/router/api-router.go b/router/api-router.go index 7b1cdef67..e2ef2f531 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -50,7 +50,6 @@ func SetApiRouter(router *gin.Engine) { // Universal secure verification routes apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify) - apiRouter.GET("/verify/status", middleware.UserAuth(), controller.GetVerificationStatus) userRoute := apiRouter.Group("/user") { @@ -247,10 +246,10 @@ func SetApiRouter(router *gin.Engine) { } usageRoute := apiRouter.Group("/usage") - usageRoute.Use(middleware.CriticalRateLimit()) + usageRoute.Use(middleware.CORS(), middleware.CriticalRateLimit()) { tokenUsageRoute := usageRoute.Group("/token") - tokenUsageRoute.Use(middleware.TokenAuth()) + tokenUsageRoute.Use(middleware.TokenAuthReadOnly()) { tokenUsageRoute.GET("/", controller.GetTokenUsage) } @@ -275,15 +274,15 @@ func SetApiRouter(router *gin.Engine) { logRoute.GET("/channel_affinity_usage_cache", middleware.AdminAuth(), controller.GetChannelAffinityUsageCacheStats) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) - logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) + logRoute.GET("/self/search", middleware.UserAuth(), middleware.SearchRateLimit(), controller.SearchUserLogs) dataRoute := apiRouter.Group("/data") dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates) dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates) - logRoute.Use(middleware.CORS()) + logRoute.Use(middleware.CORS(), middleware.CriticalRateLimit()) { - logRoute.GET("/token", controller.GetLogByKey) + logRoute.GET("/token", middleware.TokenAuthReadOnly(), controller.GetLogByKey) } groupRoute := apiRouter.Group("/group") groupRoute.Use(middleware.AdminAuth())