mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 09:33:10 +00:00
Fixes #3033 — users with many model limits hit PostgreSQL's varchar length constraint. The text type is supported across all three databases (SQLite, MySQL, PostgreSQL) with no length restriction.
463 lines
13 KiB
Go
463 lines
13 KiB
Go
package model
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/setting/operation_setting"
|
||
"github.com/bytedance/gopkg/util/gopool"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type Token struct {
|
||
Id int `json:"id"`
|
||
UserId int `json:"user_id" gorm:"index"`
|
||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||
Status int `json:"status" gorm:"default:1"`
|
||
Name string `json:"name" gorm:"index" `
|
||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||
UnlimitedQuota bool `json:"unlimited_quota"`
|
||
ModelLimitsEnabled bool `json:"model_limits_enabled"`
|
||
ModelLimits string `json:"model_limits" gorm:"type:text"`
|
||
AllowIps *string `json:"allow_ips" gorm:"default:''"`
|
||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||
Group string `json:"group" gorm:"default:''"`
|
||
CrossGroupRetry bool `json:"cross_group_retry"` // 跨分组重试,仅auto分组有效
|
||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||
}
|
||
|
||
func (token *Token) Clean() {
|
||
token.Key = ""
|
||
}
|
||
|
||
func (token *Token) GetIpLimits() []string {
|
||
// delete empty spaces
|
||
//split with \n
|
||
ipLimits := make([]string, 0)
|
||
if token.AllowIps == nil {
|
||
return ipLimits
|
||
}
|
||
cleanIps := strings.ReplaceAll(*token.AllowIps, " ", "")
|
||
if cleanIps == "" {
|
||
return ipLimits
|
||
}
|
||
ips := strings.Split(cleanIps, "\n")
|
||
for _, ip := range ips {
|
||
ip = strings.TrimSpace(ip)
|
||
ip = strings.ReplaceAll(ip, ",", "")
|
||
if ip != "" {
|
||
ipLimits = append(ipLimits, ip)
|
||
}
|
||
}
|
||
return ipLimits
|
||
}
|
||
|
||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||
var tokens []*Token
|
||
var err error
|
||
err = DB.Where("user_id = ?", userId).Order("id desc").Limit(num).Offset(startIdx).Find(&tokens).Error
|
||
return tokens, err
|
||
}
|
||
|
||
// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
|
||
// 规则:
|
||
// 1. 转义 ! 和 _(使用 ! 作为 ESCAPE 字符,兼容 MySQL/PostgreSQL/SQLite)
|
||
// 2. 连续的 % 合并为单个 %
|
||
// 3. 最多允许 2 个 %
|
||
// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
|
||
// 5. 不含 % 时按精确匹配
|
||
func sanitizeLikePattern(input string) (string, error) {
|
||
// 1. 先转义 ESCAPE 字符 ! 自身,再转义 _
|
||
// 使用 ! 而非 \ 作为 ESCAPE 字符,避免 MySQL 中反斜杠的字符串转义问题
|
||
input = strings.ReplaceAll(input, "!", "!!")
|
||
input = strings.ReplaceAll(input, `_`, `!_`)
|
||
|
||
// 2. 连续的 % 直接拒绝
|
||
if strings.Contains(input, "%%") {
|
||
return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
|
||
}
|
||
|
||
// 3. 统计 % 数量,不得超过 2
|
||
count := strings.Count(input, "%")
|
||
if count > 2 {
|
||
return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
|
||
}
|
||
|
||
// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
|
||
if count > 0 {
|
||
stripped := strings.ReplaceAll(input, "%", "")
|
||
if len(stripped) < 2 {
|
||
return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
|
||
}
|
||
return input, nil
|
||
}
|
||
|
||
// 5. 无 % 时,精确全匹配
|
||
return input, nil
|
||
}
|
||
|
||
const searchHardLimit = 100
|
||
|
||
func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
|
||
// model 层强制截断
|
||
if limit <= 0 || limit > searchHardLimit {
|
||
limit = searchHardLimit
|
||
}
|
||
if offset < 0 {
|
||
offset = 0
|
||
}
|
||
|
||
if token != "" {
|
||
token = strings.TrimPrefix(token, "sk-")
|
||
}
|
||
|
||
// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
|
||
maxTokens := operation_setting.GetMaxUserTokens()
|
||
hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
|
||
if hasFuzzy {
|
||
count, err := CountUserTokens(userId)
|
||
if err != nil {
|
||
common.SysLog("failed to count user tokens: " + err.Error())
|
||
return nil, 0, errors.New("获取令牌数量失败")
|
||
}
|
||
if int(count) > maxTokens {
|
||
return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
|
||
}
|
||
}
|
||
|
||
baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
|
||
|
||
// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
|
||
if keyword != "" {
|
||
keywordPattern, err := sanitizeLikePattern(keyword)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
baseQuery = baseQuery.Where("name LIKE ? ESCAPE '!'", keywordPattern)
|
||
}
|
||
if token != "" {
|
||
tokenPattern, err := sanitizeLikePattern(token)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '!'", tokenPattern)
|
||
}
|
||
|
||
// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT)
|
||
err = baseQuery.Limit(maxTokens).Count(&total).Error
|
||
if err != nil {
|
||
common.SysError("failed to count search tokens: " + err.Error())
|
||
return nil, 0, errors.New("搜索令牌失败")
|
||
}
|
||
|
||
// 再分页查数据
|
||
err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
|
||
if err != nil {
|
||
common.SysError("failed to search tokens: " + err.Error())
|
||
return nil, 0, errors.New("搜索令牌失败")
|
||
}
|
||
return tokens, total, nil
|
||
}
|
||
|
||
func ValidateUserToken(key string) (token *Token, err error) {
|
||
if key == "" {
|
||
return nil, errors.New("未提供令牌")
|
||
}
|
||
token, err = GetTokenByKey(key, false)
|
||
if err == nil {
|
||
if token.Status == common.TokenStatusExhausted {
|
||
keyPrefix := key[:3]
|
||
keySuffix := key[len(key)-3:]
|
||
return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
|
||
} else if token.Status == common.TokenStatusExpired {
|
||
return token, errors.New("该令牌已过期")
|
||
}
|
||
if token.Status != common.TokenStatusEnabled {
|
||
return token, errors.New("该令牌状态不可用")
|
||
}
|
||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
||
if !common.RedisEnabled {
|
||
token.Status = common.TokenStatusExpired
|
||
err := token.SelectUpdate()
|
||
if err != nil {
|
||
common.SysLog("failed to update token status" + err.Error())
|
||
}
|
||
}
|
||
return token, errors.New("该令牌已过期")
|
||
}
|
||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||
if !common.RedisEnabled {
|
||
// in this case, we can make sure the token is exhausted
|
||
token.Status = common.TokenStatusExhausted
|
||
err := token.SelectUpdate()
|
||
if err != nil {
|
||
common.SysLog("failed to update token status" + err.Error())
|
||
}
|
||
}
|
||
keyPrefix := key[:3]
|
||
keySuffix := key[len(key)-3:]
|
||
return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
|
||
}
|
||
return token, nil
|
||
}
|
||
common.SysLog("ValidateUserToken: failed to get token: " + err.Error())
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, errors.New("无效的令牌")
|
||
} else {
|
||
return nil, errors.New("无效的令牌,数据库查询出错,请联系管理员")
|
||
}
|
||
}
|
||
|
||
func GetTokenByIds(id int, userId int) (*Token, error) {
|
||
if id == 0 || userId == 0 {
|
||
return nil, errors.New("id 或 userId 为空!")
|
||
}
|
||
token := Token{Id: id, UserId: userId}
|
||
var err error = nil
|
||
err = DB.First(&token, "id = ? and user_id = ?", id, userId).Error
|
||
return &token, err
|
||
}
|
||
|
||
func GetTokenById(id int) (*Token, error) {
|
||
if id == 0 {
|
||
return nil, errors.New("id 为空!")
|
||
}
|
||
token := Token{Id: id}
|
||
var err error = nil
|
||
err = DB.First(&token, "id = ?", id).Error
|
||
if shouldUpdateRedis(true, err) {
|
||
gopool.Go(func() {
|
||
if err := cacheSetToken(token); err != nil {
|
||
common.SysLog("failed to update user status cache: " + err.Error())
|
||
}
|
||
})
|
||
}
|
||
return &token, err
|
||
}
|
||
|
||
func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
||
defer func() {
|
||
// Update Redis cache asynchronously on successful DB read
|
||
if shouldUpdateRedis(fromDB, err) && token != nil {
|
||
gopool.Go(func() {
|
||
if err := cacheSetToken(*token); err != nil {
|
||
common.SysLog("failed to update user status cache: " + err.Error())
|
||
}
|
||
})
|
||
}
|
||
}()
|
||
if !fromDB && common.RedisEnabled {
|
||
// Try Redis first
|
||
token, err := cacheGetTokenByKey(key)
|
||
if err == nil {
|
||
return token, nil
|
||
}
|
||
// Don't return error - fall through to DB
|
||
}
|
||
fromDB = true
|
||
err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
|
||
return token, err
|
||
}
|
||
|
||
func (token *Token) Insert() error {
|
||
var err error
|
||
err = DB.Create(token).Error
|
||
return err
|
||
}
|
||
|
||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||
func (token *Token) Update() (err error) {
|
||
defer func() {
|
||
if shouldUpdateRedis(true, err) {
|
||
gopool.Go(func() {
|
||
err := cacheSetToken(*token)
|
||
if err != nil {
|
||
common.SysLog("failed to update token cache: " + err.Error())
|
||
}
|
||
})
|
||
}
|
||
}()
|
||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
|
||
"model_limits_enabled", "model_limits", "allow_ips", "group", "cross_group_retry").Updates(token).Error
|
||
return err
|
||
}
|
||
|
||
func (token *Token) SelectUpdate() (err error) {
|
||
defer func() {
|
||
if shouldUpdateRedis(true, err) {
|
||
gopool.Go(func() {
|
||
err := cacheSetToken(*token)
|
||
if err != nil {
|
||
common.SysLog("failed to update token cache: " + err.Error())
|
||
}
|
||
})
|
||
}
|
||
}()
|
||
// This can update zero values
|
||
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
|
||
}
|
||
|
||
func (token *Token) Delete() (err error) {
|
||
defer func() {
|
||
if shouldUpdateRedis(true, err) {
|
||
gopool.Go(func() {
|
||
err := cacheDeleteToken(token.Key)
|
||
if err != nil {
|
||
common.SysLog("failed to delete token cache: " + err.Error())
|
||
}
|
||
})
|
||
}
|
||
}()
|
||
err = DB.Delete(token).Error
|
||
return err
|
||
}
|
||
|
||
func (token *Token) IsModelLimitsEnabled() bool {
|
||
return token.ModelLimitsEnabled
|
||
}
|
||
|
||
func (token *Token) GetModelLimits() []string {
|
||
if token.ModelLimits == "" {
|
||
return []string{}
|
||
}
|
||
return strings.Split(token.ModelLimits, ",")
|
||
}
|
||
|
||
func (token *Token) GetModelLimitsMap() map[string]bool {
|
||
limits := token.GetModelLimits()
|
||
limitsMap := make(map[string]bool)
|
||
for _, limit := range limits {
|
||
limitsMap[limit] = true
|
||
}
|
||
return limitsMap
|
||
}
|
||
|
||
func DisableModelLimits(tokenId int) error {
|
||
token, err := GetTokenById(tokenId)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
token.ModelLimitsEnabled = false
|
||
token.ModelLimits = ""
|
||
return token.Update()
|
||
}
|
||
|
||
func DeleteTokenById(id int, userId int) (err error) {
|
||
// Why we need userId here? In case user want to delete other's token.
|
||
if id == 0 || userId == 0 {
|
||
return errors.New("id 或 userId 为空!")
|
||
}
|
||
token := Token{Id: id, UserId: userId}
|
||
err = DB.Where(token).First(&token).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return token.Delete()
|
||
}
|
||
|
||
func IncreaseTokenQuota(tokenId int, key string, quota int) (err error) {
|
||
if quota < 0 {
|
||
return errors.New("quota 不能为负数!")
|
||
}
|
||
if common.RedisEnabled {
|
||
gopool.Go(func() {
|
||
err := cacheIncrTokenQuota(key, int64(quota))
|
||
if err != nil {
|
||
common.SysLog("failed to increase token quota: " + err.Error())
|
||
}
|
||
})
|
||
}
|
||
if common.BatchUpdateEnabled {
|
||
addNewRecord(BatchUpdateTypeTokenQuota, tokenId, quota)
|
||
return nil
|
||
}
|
||
return increaseTokenQuota(tokenId, quota)
|
||
}
|
||
|
||
func increaseTokenQuota(id int, quota int) (err error) {
|
||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||
map[string]interface{}{
|
||
"remain_quota": gorm.Expr("remain_quota + ?", quota),
|
||
"used_quota": gorm.Expr("used_quota - ?", quota),
|
||
"accessed_time": common.GetTimestamp(),
|
||
},
|
||
).Error
|
||
return err
|
||
}
|
||
|
||
func DecreaseTokenQuota(id int, key string, quota int) (err error) {
|
||
if quota < 0 {
|
||
return errors.New("quota 不能为负数!")
|
||
}
|
||
if common.RedisEnabled {
|
||
gopool.Go(func() {
|
||
err := cacheDecrTokenQuota(key, int64(quota))
|
||
if err != nil {
|
||
common.SysLog("failed to decrease token quota: " + err.Error())
|
||
}
|
||
})
|
||
}
|
||
if common.BatchUpdateEnabled {
|
||
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||
return nil
|
||
}
|
||
return decreaseTokenQuota(id, quota)
|
||
}
|
||
|
||
func decreaseTokenQuota(id int, quota int) (err error) {
|
||
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
|
||
map[string]interface{}{
|
||
"remain_quota": gorm.Expr("remain_quota - ?", quota),
|
||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||
"accessed_time": common.GetTimestamp(),
|
||
},
|
||
).Error
|
||
return err
|
||
}
|
||
|
||
// CountUserTokens returns total number of tokens for the given user, used for pagination
|
||
func CountUserTokens(userId int) (int64, error) {
|
||
var total int64
|
||
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
||
return total, err
|
||
}
|
||
|
||
// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
|
||
func BatchDeleteTokens(ids []int, userId int) (int, error) {
|
||
if len(ids) == 0 {
|
||
return 0, errors.New("ids 不能为空!")
|
||
}
|
||
|
||
tx := DB.Begin()
|
||
|
||
var tokens []Token
|
||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
|
||
tx.Rollback()
|
||
return 0, err
|
||
}
|
||
|
||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
|
||
tx.Rollback()
|
||
return 0, err
|
||
}
|
||
|
||
if err := tx.Commit().Error; err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
if common.RedisEnabled {
|
||
gopool.Go(func() {
|
||
for _, t := range tokens {
|
||
_ = cacheDeleteToken(t.Key)
|
||
}
|
||
})
|
||
}
|
||
|
||
return len(tokens), nil
|
||
}
|