refactor: enhance API security with read-only token authentication and improved rate limiting

This commit is contained in:
CaIon
2026-02-06 21:26:26 +08:00
parent 04dd761880
commit d814d62e2f
6 changed files with 119 additions and 153 deletions

View File

@@ -39,7 +39,7 @@ var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10 var ItemsPerPage = 10
var MaxRecentItems = 100 var MaxRecentItems = 1000
var PasswordLoginEnabled = true var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true var PasswordRegisterEnabled = true

View File

@@ -53,40 +53,32 @@ func GetUserLogs(c *gin.Context) {
return return
} }
// Deprecated: SearchAllLogs 已废弃,前端未使用该接口。
func SearchAllLogs(c *gin.Context) { func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": false,
"message": "", "message": "该接口已废弃",
"data": logs,
}) })
return
} }
// Deprecated: SearchUserLogs 已废弃,前端未使用该接口。
func SearchUserLogs(c *gin.Context) { func SearchUserLogs(c *gin.Context) {
keyword := c.Query("keyword")
userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": false,
"message": "", "message": "该接口已废弃",
"data": logs,
}) })
return
} }
func GetLogByKey(c *gin.Context) { func GetLogByKey(c *gin.Context) {
key := c.Query("key") tokenId := c.GetInt("token_id")
logs, err := model.GetLogByKey(key) if tokenId == 0 {
c.JSON(200, gin.H{
"success": false,
"message": "无效的令牌",
})
return
}
logs, err := model.GetLogByTokenId(tokenId)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": false, "success": false,
@@ -110,7 +102,11 @@ func GetLogsStat(c *gin.Context) {
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group") group := c.Query("group")
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) stat, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
if err != nil {
common.ApiError(c, err)
return
}
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
@@ -133,7 +129,11 @@ func GetLogsSelfStat(c *gin.Context) {
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) channel, _ := strconv.Atoi(c.Query("channel"))
group := c.Query("group") group := c.Query("group")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group) quotaNum, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
if err != nil {
common.ApiError(c, err)
return
}
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,

View File

@@ -133,94 +133,6 @@ func UniversalVerify(c *gin.Context) {
}) })
} }
// GetVerificationStatus 获取验证状态
func GetVerificationStatus(c *gin.Context) {
userId := c.GetInt("id")
if userId == 0 {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "未登录",
})
return
}
session := sessions.Default(c)
verifiedAtRaw := session.Get(SecureVerificationSessionKey)
if verifiedAtRaw == nil {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: false,
},
})
return
}
verifiedAt, ok := verifiedAtRaw.(int64)
if !ok {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: false,
},
})
return
}
elapsed := time.Now().Unix() - verifiedAt
if elapsed >= SecureVerificationTimeout {
// 验证已过期
session.Delete(SecureVerificationSessionKey)
_ = session.Save()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: false,
},
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": VerificationStatusResponse{
Verified: true,
ExpiresAt: verifiedAt + SecureVerificationTimeout,
},
})
}
// CheckSecureVerification 检查是否已通过安全验证
// 返回 true 表示验证有效false 表示需要重新验证
func CheckSecureVerification(c *gin.Context) bool {
session := sessions.Default(c)
verifiedAtRaw := session.Get(SecureVerificationSessionKey)
if verifiedAtRaw == nil {
return false
}
verifiedAt, ok := verifiedAtRaw.(int64)
if !ok {
return false
}
elapsed := time.Now().Unix() - verifiedAt
if elapsed >= SecureVerificationTimeout {
// 验证已过期,清除 session
session.Delete(SecureVerificationSessionKey)
_ = session.Save()
return false
}
return true
}
// PasskeyVerifyAndSetSession Passkey 验证完成后设置 session // PasskeyVerifyAndSetSession Passkey 验证完成后设置 session
// 这是一个辅助函数,供 PasskeyVerifyFinish 调用 // 这是一个辅助函数,供 PasskeyVerifyFinish 调用
func PasskeyVerifyAndSetSession(c *gin.Context) { func PasskeyVerifyAndSetSession(c *gin.Context) {

View File

@@ -168,6 +168,63 @@ func WssAuth(c *gin.Context) {
} }
// TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
// 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
// 即使令牌已过期、已耗尽或已禁用,也允许访问。
// 仍然检查用户是否被封禁。
func TokenAuthReadOnly() func(c *gin.Context) {
return func(c *gin.Context) {
key := c.Request.Header.Get("Authorization")
if key == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "未提供 Authorization 请求头",
})
c.Abort()
return
}
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
key = strings.TrimSpace(key[7:])
}
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
key = parts[0]
token, err := model.GetTokenByKey(key, false)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "无效的令牌",
})
c.Abort()
return
}
userCache, err := model.GetUserCache(token.UserId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
c.Abort()
return
}
if userCache.Status != common.UserStatusEnabled {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "用户已被封禁",
})
c.Abort()
return
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_key", token.Key)
c.Next()
}
}
func TokenAuth() func(c *gin.Context) { func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
// 先检测是否为ws // 先检测是否为ws

View File

@@ -2,9 +2,8 @@ package model
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os"
"strings"
"time" "time"
"github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/common"
@@ -66,16 +65,8 @@ func formatUserLogs(logs []*Log) {
} }
} }
func GetLogByKey(key string) (logs []*Log, err error) { func GetLogByTokenId(tokenId int) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" { err = LOG_DB.Model(&Log{}).Where("token_id = ?", tokenId).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
var tk Token
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
} else {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
}
formatUserLogs(logs) formatUserLogs(logs)
return logs, err return logs, err
} }
@@ -276,6 +267,8 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return logs, total, err return logs, total, err
} }
const logSearchCountLimit = 10000
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) { func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
@@ -285,7 +278,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("logs.model_name like ?", modelName) modelNamePattern, err := sanitizeLikePattern(modelName)
if err != nil {
return nil, 0, err
}
tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern)
} }
if tokenName != "" { if tokenName != "" {
tx = tx.Where("logs.token_name = ?", tokenName) tx = tx.Where("logs.token_name = ?", tokenName)
@@ -302,37 +299,28 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if group != "" { if group != "" {
tx = tx.Where("logs."+logGroupCol+" = ?", group) tx = tx.Where("logs."+logGroupCol+" = ?", group)
} }
err = tx.Model(&Log{}).Count(&total).Error err = tx.Model(&Log{}).Limit(logSearchCountLimit).Count(&total).Error
if err != nil { if err != nil {
return nil, 0, err common.SysError("failed to count user logs: " + err.Error())
return nil, 0, errors.New("查询日志失败")
} }
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
if err != nil { if err != nil {
return nil, 0, err common.SysError("failed to search user logs: " + err.Error())
return nil, 0, errors.New("查询日志失败")
} }
formatUserLogs(logs) formatUserLogs(logs)
return logs, total, err return logs, total, err
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
formatUserLogs(logs)
return logs, err
}
type Stat struct { type Stat struct {
Quota int `json:"quota"` Quota int `json:"quota"`
Rpm int `json:"rpm"` Rpm int `json:"rpm"`
Tpm int `json:"tpm"` Tpm int `json:"tpm"`
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
tx := LOG_DB.Table("logs").Select("sum(quota) quota") tx := LOG_DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询 // 为rpm和tpm创建单独的查询
@@ -353,8 +341,12 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", modelName) modelNamePattern, err := sanitizeLikePattern(modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName) if err != nil {
return stat, err
}
tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
@@ -372,10 +364,16 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix()) rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// 执行查询 // 执行查询
tx.Scan(&stat) if err := tx.Scan(&stat).Error; err != nil {
rpmTpmQuery.Scan(&stat) common.SysError("failed to query log stat: " + err.Error())
return stat, errors.New("查询统计数据失败")
}
if err := rpmTpmQuery.Scan(&stat).Error; err != nil {
common.SysError("failed to query rpm/tpm stat: " + err.Error())
return stat, errors.New("查询统计数据失败")
}
return stat return stat, nil
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {

View File

@@ -50,7 +50,6 @@ func SetApiRouter(router *gin.Engine) {
// Universal secure verification routes // Universal secure verification routes
apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify) apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
apiRouter.GET("/verify/status", middleware.UserAuth(), controller.GetVerificationStatus)
userRoute := apiRouter.Group("/user") userRoute := apiRouter.Group("/user")
{ {
@@ -247,10 +246,10 @@ func SetApiRouter(router *gin.Engine) {
} }
usageRoute := apiRouter.Group("/usage") usageRoute := apiRouter.Group("/usage")
usageRoute.Use(middleware.CriticalRateLimit()) usageRoute.Use(middleware.CORS(), middleware.CriticalRateLimit())
{ {
tokenUsageRoute := usageRoute.Group("/token") tokenUsageRoute := usageRoute.Group("/token")
tokenUsageRoute.Use(middleware.TokenAuth()) tokenUsageRoute.Use(middleware.TokenAuthReadOnly())
{ {
tokenUsageRoute.GET("/", controller.GetTokenUsage) tokenUsageRoute.GET("/", controller.GetTokenUsage)
} }
@@ -275,15 +274,15 @@ func SetApiRouter(router *gin.Engine) {
logRoute.GET("/channel_affinity_usage_cache", middleware.AdminAuth(), controller.GetChannelAffinityUsageCacheStats) logRoute.GET("/channel_affinity_usage_cache", middleware.AdminAuth(), controller.GetChannelAffinityUsageCacheStats)
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) logRoute.GET("/self/search", middleware.UserAuth(), middleware.SearchRateLimit(), controller.SearchUserLogs)
dataRoute := apiRouter.Group("/data") dataRoute := apiRouter.Group("/data")
dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates) dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates)
dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates) dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates)
logRoute.Use(middleware.CORS()) logRoute.Use(middleware.CORS(), middleware.CriticalRateLimit())
{ {
logRoute.GET("/token", controller.GetLogByKey) logRoute.GET("/token", middleware.TokenAuthReadOnly(), controller.GetLogByKey)
} }
groupRoute := apiRouter.Group("/group") groupRoute := apiRouter.Group("/group")
groupRoute.Use(middleware.AdminAuth()) groupRoute.Use(middleware.AdminAuth())