mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 01:45:41 +00:00
refactor: enhance API security with read-only token authentication and improved rate limiting
This commit is contained in:
@@ -39,7 +39,7 @@ var OptionMap map[string]string
|
||||
var OptionMapRWMutex sync.RWMutex
|
||||
|
||||
var ItemsPerPage = 10
|
||||
var MaxRecentItems = 100
|
||||
var MaxRecentItems = 1000
|
||||
|
||||
var PasswordLoginEnabled = true
|
||||
var PasswordRegisterEnabled = true
|
||||
|
||||
@@ -53,40 +53,32 @@ func GetUserLogs(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Deprecated: SearchAllLogs 已废弃,前端未使用该接口。
|
||||
func SearchAllLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
logs, err := model.SearchAllLogs(keyword)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"success": false,
|
||||
"message": "该接口已废弃",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Deprecated: SearchUserLogs 已废弃,前端未使用该接口。
|
||||
func SearchUserLogs(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
userId := c.GetInt("id")
|
||||
logs, err := model.SearchUserLogs(userId, keyword)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"success": false,
|
||||
"message": "该接口已废弃",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetLogByKey(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
logs, err := model.GetLogByKey(key)
|
||||
tokenId := c.GetInt("token_id")
|
||||
if tokenId == 0 {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的令牌",
|
||||
})
|
||||
return
|
||||
}
|
||||
logs, err := model.GetLogByTokenId(tokenId)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
@@ -110,7 +102,11 @@ func GetLogsStat(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
stat, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -133,7 +129,11 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
group := c.Query("group")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
quotaNum, err := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
|
||||
@@ -133,94 +133,6 @@ func UniversalVerify(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetVerificationStatus 获取验证状态
|
||||
func GetVerificationStatus(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
if userId == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "未登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
verifiedAtRaw := session.Get(SecureVerificationSessionKey)
|
||||
|
||||
if verifiedAtRaw == nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": VerificationStatusResponse{
|
||||
Verified: false,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
verifiedAt, ok := verifiedAtRaw.(int64)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": VerificationStatusResponse{
|
||||
Verified: false,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
elapsed := time.Now().Unix() - verifiedAt
|
||||
if elapsed >= SecureVerificationTimeout {
|
||||
// 验证已过期
|
||||
session.Delete(SecureVerificationSessionKey)
|
||||
_ = session.Save()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": VerificationStatusResponse{
|
||||
Verified: false,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": VerificationStatusResponse{
|
||||
Verified: true,
|
||||
ExpiresAt: verifiedAt + SecureVerificationTimeout,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// CheckSecureVerification 检查是否已通过安全验证
|
||||
// 返回 true 表示验证有效,false 表示需要重新验证
|
||||
func CheckSecureVerification(c *gin.Context) bool {
|
||||
session := sessions.Default(c)
|
||||
verifiedAtRaw := session.Get(SecureVerificationSessionKey)
|
||||
|
||||
if verifiedAtRaw == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
verifiedAt, ok := verifiedAtRaw.(int64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
elapsed := time.Now().Unix() - verifiedAt
|
||||
if elapsed >= SecureVerificationTimeout {
|
||||
// 验证已过期,清除 session
|
||||
session.Delete(SecureVerificationSessionKey)
|
||||
_ = session.Save()
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// PasskeyVerifyAndSetSession Passkey 验证完成后设置 session
|
||||
// 这是一个辅助函数,供 PasskeyVerifyFinish 调用
|
||||
func PasskeyVerifyAndSetSession(c *gin.Context) {
|
||||
|
||||
@@ -168,6 +168,63 @@ func WssAuth(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
// TokenAuthReadOnly 宽松版本的令牌认证中间件,用于只读查询接口。
|
||||
// 只验证令牌 key 是否存在,不检查令牌状态、过期时间和额度。
|
||||
// 即使令牌已过期、已耗尽或已禁用,也允许访问。
|
||||
// 仍然检查用户是否被封禁。
|
||||
func TokenAuthReadOnly() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
if key == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "未提供 Authorization 请求头",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(key, "Bearer ") || strings.HasPrefix(key, "bearer ") {
|
||||
key = strings.TrimSpace(key[7:])
|
||||
}
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts := strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
|
||||
token, err := model.GetTokenByKey(key, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的令牌",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
userCache, err := model.GetUserCache(token.UserId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if userCache.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已被封禁",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_key", token.Key)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
// 先检测是否为ws
|
||||
|
||||
64
model/log.go
64
model/log.go
@@ -2,9 +2,8 @@ package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -66,16 +65,8 @@ func formatUserLogs(logs []*Log) {
|
||||
}
|
||||
}
|
||||
|
||||
func GetLogByKey(key string) (logs []*Log, err error) {
|
||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||
var tk Token
|
||||
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
|
||||
} else {
|
||||
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
|
||||
}
|
||||
func GetLogByTokenId(tokenId int) (logs []*Log, err error) {
|
||||
err = LOG_DB.Model(&Log{}).Where("token_id = ?", tokenId).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
||||
formatUserLogs(logs)
|
||||
return logs, err
|
||||
}
|
||||
@@ -276,6 +267,8 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
const logSearchCountLimit = 10000
|
||||
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
@@ -285,7 +278,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
}
|
||||
|
||||
if modelName != "" {
|
||||
tx = tx.Where("logs.model_name like ?", modelName)
|
||||
modelNamePattern, err := sanitizeLikePattern(modelName)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
||||
}
|
||||
if tokenName != "" {
|
||||
tx = tx.Where("logs.token_name = ?", tokenName)
|
||||
@@ -302,37 +299,28 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
if group != "" {
|
||||
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
err = tx.Model(&Log{}).Limit(logSearchCountLimit).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
common.SysError("failed to count user logs: " + err.Error())
|
||||
return nil, 0, errors.New("查询日志失败")
|
||||
}
|
||||
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
common.SysError("failed to search user logs: " + err.Error())
|
||||
return nil, 0, errors.New("查询日志失败")
|
||||
}
|
||||
|
||||
formatUserLogs(logs)
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
||||
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
||||
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
||||
formatUserLogs(logs)
|
||||
return logs, err
|
||||
}
|
||||
|
||||
type Stat struct {
|
||||
Quota int `json:"quota"`
|
||||
Rpm int `json:"rpm"`
|
||||
Tpm int `json:"tpm"`
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
|
||||
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
|
||||
|
||||
// 为rpm和tpm创建单独的查询
|
||||
@@ -353,8 +341,12 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if modelName != "" {
|
||||
tx = tx.Where("model_name like ?", modelName)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
|
||||
modelNamePattern, err := sanitizeLikePattern(modelName)
|
||||
if err != nil {
|
||||
return stat, err
|
||||
}
|
||||
tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
@@ -372,10 +364,16 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
|
||||
|
||||
// 执行查询
|
||||
tx.Scan(&stat)
|
||||
rpmTpmQuery.Scan(&stat)
|
||||
if err := tx.Scan(&stat).Error; err != nil {
|
||||
common.SysError("failed to query log stat: " + err.Error())
|
||||
return stat, errors.New("查询统计数据失败")
|
||||
}
|
||||
if err := rpmTpmQuery.Scan(&stat).Error; err != nil {
|
||||
common.SysError("failed to query rpm/tpm stat: " + err.Error())
|
||||
return stat, errors.New("查询统计数据失败")
|
||||
}
|
||||
|
||||
return stat
|
||||
return stat, nil
|
||||
}
|
||||
|
||||
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|
||||
|
||||
@@ -50,7 +50,6 @@ func SetApiRouter(router *gin.Engine) {
|
||||
|
||||
// Universal secure verification routes
|
||||
apiRouter.POST("/verify", middleware.UserAuth(), middleware.CriticalRateLimit(), controller.UniversalVerify)
|
||||
apiRouter.GET("/verify/status", middleware.UserAuth(), controller.GetVerificationStatus)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
@@ -247,10 +246,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
}
|
||||
|
||||
usageRoute := apiRouter.Group("/usage")
|
||||
usageRoute.Use(middleware.CriticalRateLimit())
|
||||
usageRoute.Use(middleware.CORS(), middleware.CriticalRateLimit())
|
||||
{
|
||||
tokenUsageRoute := usageRoute.Group("/token")
|
||||
tokenUsageRoute.Use(middleware.TokenAuth())
|
||||
tokenUsageRoute.Use(middleware.TokenAuthReadOnly())
|
||||
{
|
||||
tokenUsageRoute.GET("/", controller.GetTokenUsage)
|
||||
}
|
||||
@@ -275,15 +274,15 @@ func SetApiRouter(router *gin.Engine) {
|
||||
logRoute.GET("/channel_affinity_usage_cache", middleware.AdminAuth(), controller.GetChannelAffinityUsageCacheStats)
|
||||
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
|
||||
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
|
||||
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
|
||||
logRoute.GET("/self/search", middleware.UserAuth(), middleware.SearchRateLimit(), controller.SearchUserLogs)
|
||||
|
||||
dataRoute := apiRouter.Group("/data")
|
||||
dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates)
|
||||
dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates)
|
||||
|
||||
logRoute.Use(middleware.CORS())
|
||||
logRoute.Use(middleware.CORS(), middleware.CriticalRateLimit())
|
||||
{
|
||||
logRoute.GET("/token", controller.GetLogByKey)
|
||||
logRoute.GET("/token", middleware.TokenAuthReadOnly(), controller.GetLogByKey)
|
||||
}
|
||||
groupRoute := apiRouter.Group("/group")
|
||||
groupRoute.Use(middleware.AdminAuth())
|
||||
|
||||
Reference in New Issue
Block a user