mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-07 11:27:26 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4f17543cb | ||
|
|
1eb706de7a | ||
|
|
d13d81baba | ||
|
|
65af1a4d10 | ||
|
|
1ae0a3fb83 | ||
|
|
fe2e8f1a42 | ||
|
|
a5f7f8af29 | ||
|
|
2f01a2125f | ||
|
|
e4f9787c16 | ||
|
|
bb5e032dd2 | ||
|
|
304c92ceab | ||
|
|
05874dcca5 | ||
|
|
ca8b7ed1c3 | ||
|
|
ed435e5c8f | ||
|
|
a1b864bc5e | ||
|
|
2a15dfccea | ||
|
|
9e5a7ed541 | ||
|
|
65d1cde8fb | ||
|
|
8f4a2df5ee |
@@ -83,6 +83,7 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
||||
- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
|
||||
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
|
||||
- `CRYPTO_SECRET`: Encryption key for encrypting database content
|
||||
|
||||
## Deployment
|
||||
> [!TIP]
|
||||
@@ -93,6 +94,10 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
||||
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
|
||||
> ```
|
||||
|
||||
### Multi-Server Deployment
|
||||
- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
|
||||
- If using a public Redis, must set `CRYPTO_SECRET` environment variable, otherwise Redis content will not be able to be obtained in multi-server deployment.
|
||||
|
||||
### Requirements
|
||||
- Local database (default): SQLite (Docker deployment must mount `/data` directory)
|
||||
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
|
||||
|
||||
@@ -89,6 +89,7 @@
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
|
||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
|
||||
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
|
||||
## 部署
|
||||
> [!TIP]
|
||||
> 最新版Docker镜像:`calciumion/new-api:latest`
|
||||
@@ -98,6 +99,10 @@
|
||||
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
|
||||
> ```
|
||||
|
||||
### 多机部署
|
||||
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
|
||||
- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取。
|
||||
|
||||
### 部署要求
|
||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||
- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6
|
||||
|
||||
@@ -30,6 +30,7 @@ var DefaultCollapseSidebar = false // default value of collapse sidebar
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
var SessionSecret = uuid.New().String()
|
||||
var CryptoSecret = uuid.New().String()
|
||||
|
||||
var OptionMap map[string]string
|
||||
var OptionMapRWMutex sync.RWMutex
|
||||
|
||||
@@ -1,6 +1,23 @@
|
||||
package common
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func GenerateHMACWithKey(key []byte, data string) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func GenerateHMAC(data string) string {
|
||||
h := hmac.New(sha256.New, []byte(CryptoSecret))
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func Password2Hash(password string) (string, error) {
|
||||
passwordBytes := []byte(password)
|
||||
|
||||
@@ -22,7 +22,7 @@ func printHelp() {
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
}
|
||||
|
||||
func init() {
|
||||
func LoadEnv() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
@@ -45,6 +45,11 @@ func init() {
|
||||
SessionSecret = ss
|
||||
}
|
||||
}
|
||||
if os.Getenv("CRYPTO_SECRET") != "" {
|
||||
CryptoSecret = os.Getenv("CRYPTO_SECRET")
|
||||
} else {
|
||||
CryptoSecret = SessionSecret
|
||||
}
|
||||
if os.Getenv("SQLITE_PATH") != "" {
|
||||
SQLitePath = os.Getenv("SQLITE_PATH")
|
||||
}
|
||||
|
||||
216
common/redis.go
216
common/redis.go
@@ -2,9 +2,15 @@ package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var RDB *redis.Client
|
||||
@@ -56,39 +62,167 @@ func RedisGet(key string) (string, error) {
|
||||
return RDB.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
func RedisExpire(key string, expiration time.Duration) error {
|
||||
ctx := context.Background()
|
||||
return RDB.Expire(ctx, key, expiration).Err()
|
||||
}
|
||||
|
||||
func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
||||
ctx := context.Background()
|
||||
return RDB.GetSet(ctx, key, expiration).Result()
|
||||
}
|
||||
//func RedisExpire(key string, expiration time.Duration) error {
|
||||
// ctx := context.Background()
|
||||
// return RDB.Expire(ctx, key, expiration).Err()
|
||||
//}
|
||||
//
|
||||
//func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
||||
// ctx := context.Background()
|
||||
// return RDB.GetSet(ctx, key, expiration).Result()
|
||||
//}
|
||||
|
||||
func RedisDel(key string) error {
|
||||
ctx := context.Background()
|
||||
return RDB.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisDecrease(key string, value int64) error {
|
||||
func RedisHDelObj(key string) error {
|
||||
ctx := context.Background()
|
||||
return RDB.HDel(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
ctx := context.Background()
|
||||
|
||||
data := make(map[string]interface{})
|
||||
|
||||
// 使用反射遍历结构体字段
|
||||
v := reflect.ValueOf(obj).Elem()
|
||||
t := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
value := v.Field(i)
|
||||
|
||||
// Skip DeletedAt field
|
||||
if field.Type.String() == "gorm.DeletedAt" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理指针类型
|
||||
if value.Kind() == reflect.Ptr {
|
||||
if value.IsNil() {
|
||||
data[field.Name] = ""
|
||||
continue
|
||||
}
|
||||
value = value.Elem()
|
||||
}
|
||||
|
||||
// 处理布尔类型
|
||||
if value.Kind() == reflect.Bool {
|
||||
data[field.Name] = strconv.FormatBool(value.Bool())
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他类型直接转换为字符串
|
||||
data[field.Name] = fmt.Sprintf("%v", value.Interface())
|
||||
}
|
||||
|
||||
txn := RDB.TxPipeline()
|
||||
txn.HSet(ctx, key, data)
|
||||
txn.Expire(ctx, key, expiration)
|
||||
|
||||
_, err := txn.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHGetObj(key string, obj interface{}) error {
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := RDB.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load hash from Redis: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return fmt.Errorf("key %s not found in Redis", key)
|
||||
}
|
||||
|
||||
// Handle both pointer and non-pointer values
|
||||
val := reflect.ValueOf(obj)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
|
||||
}
|
||||
|
||||
v := val.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldName := field.Name
|
||||
if value, ok := result[fieldName]; ok {
|
||||
fieldValue := v.Field(i)
|
||||
|
||||
// Handle pointer types
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
// Enhanced type handling for Token struct
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.String:
|
||||
fieldValue.SetString(value)
|
||||
case reflect.Int, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.SetInt(intValue)
|
||||
case reflect.Bool:
|
||||
boolValue, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.SetBool(boolValue)
|
||||
case reflect.Struct:
|
||||
// Special handling for gorm.DeletedAt
|
||||
if fieldValue.Type().String() == "gorm.DeletedAt" {
|
||||
if value != "" {
|
||||
timeValue, err := time.Parse(time.RFC3339, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RedisIncr Add this function to handle atomic increments
|
||||
func RedisIncr(key string, delta int64) error {
|
||||
// 检查键的剩余生存时间
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil {
|
||||
// 失败则尝试直接减少
|
||||
return RDB.DecrBy(context.Background(), key, value).Err()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
// 如果剩余生存时间大于0,则进行减少操作
|
||||
// 只有在 key 存在且有 TTL 时才需要特殊处理
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
// 开始一个Redis事务
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
// 减少余额
|
||||
decrCmd := txn.DecrBy(ctx, key, value)
|
||||
decrCmd := txn.IncrBy(ctx, key, delta)
|
||||
if err := decrCmd.Err(); err != nil {
|
||||
return err // 如果减少失败,则直接返回错误
|
||||
}
|
||||
@@ -99,8 +233,54 @@ func RedisDecrease(key string, value int64) error {
|
||||
// 执行事务
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
} else {
|
||||
_ = RedisDel(key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHIncrBy(key, field string, delta int64) error {
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
incrCmd := txn.HIncrBy(ctx, key, field, delta)
|
||||
if err := incrCmd.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHSetField(key, field string, value interface{}) error {
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
hsetCmd := txn.HSet(ctx, key, field, value)
|
||||
if err := hsetCmd.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
23
constant/cache_key.go
Normal file
23
constant/cache_key.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
UserGroupKeyFmt = "user_group:%d"
|
||||
UserQuotaKeyFmt = "user_quota:%d"
|
||||
UserEnabledKeyFmt = "user_enabled:%d"
|
||||
UserUsernameKeyFmt = "user_name:%d"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenFiledRemainQuota = "RemainQuota"
|
||||
TokenFieldGroup = "Group"
|
||||
)
|
||||
@@ -21,7 +21,7 @@ func GetSubscription(c *gin.Context) {
|
||||
usedQuota = token.UsedQuota
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
remainQuota, err = model.GetUserQuota(userId)
|
||||
remainQuota, err = model.GetUserQuota(userId, false)
|
||||
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if expiredTime <= 0 {
|
||||
|
||||
@@ -20,15 +20,18 @@ func GetGroups(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetUserGroups(c *gin.Context) {
|
||||
usableGroups := make(map[string]string)
|
||||
usableGroups := make(map[string]map[string]interface{})
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.CacheGetUserGroup(userId)
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if _, ok := userUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = userUsableGroups[groupName]
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = map[string]interface{}{
|
||||
"ratio": ratio,
|
||||
"desc": desc,
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -166,7 +166,7 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId)
|
||||
userGroup, err := model.GetUserGroup(userId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
@@ -62,5 +64,6 @@ func Playground(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
@@ -153,7 +153,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
|
||||
@@ -75,7 +75,7 @@ func RequestEpay(c *gin.Context) {
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
@@ -236,7 +236,7 @@ func RequestAmount(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
|
||||
7
main.go
7
main.go
@@ -33,9 +33,11 @@ var indexPage []byte
|
||||
func main() {
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
common.SysError("failed to load .env file: " + err.Error())
|
||||
common.SysLog("Support for .env file is disabled")
|
||||
}
|
||||
|
||||
common.LoadEnv()
|
||||
|
||||
common.SetupLogger()
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
@@ -80,9 +82,6 @@ func main() {
|
||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||
model.InitChannelCache()
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
go model.SyncTokenCache(common.SyncFrequency)
|
||||
}
|
||||
if common.MemoryCacheEnabled {
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
go model.SyncChannelCache(common.SyncFrequency)
|
||||
|
||||
@@ -201,7 +201,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
userEnabled, err := model.IsUserEnabled(token.UserId, false)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
|
||||
@@ -40,7 +40,7 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
userGroup, _ := model.GetUserGroup(userId, false)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
|
||||
308
model/cache.go
308
model/cache.go
@@ -1,209 +1,115 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
|
||||
// 仅用于定时同步缓存
|
||||
var token2UserId = make(map[string]int)
|
||||
var token2UserIdLock sync.RWMutex
|
||||
|
||||
func cacheSetToken(token *Token) error {
|
||||
jsonBytes, err := json.Marshal(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
|
||||
return err
|
||||
}
|
||||
token2UserIdLock.Lock()
|
||||
defer token2UserIdLock.Unlock()
|
||||
token2UserId[token.Key] = token.UserId
|
||||
return nil
|
||||
}
|
||||
|
||||
// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetTokenByKey(key)
|
||||
}
|
||||
var token *Token
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
// 如果缓存中不存在,则从数据库中获取
|
||||
token, err = GetTokenByKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = cacheSetToken(token)
|
||||
return token, nil
|
||||
}
|
||||
// 如果缓存中存在,则续期时间
|
||||
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
err = json.Unmarshal([]byte(tokenObjectString), &token)
|
||||
return token, err
|
||||
}
|
||||
|
||||
func SyncTokenCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing tokens from database")
|
||||
token2UserIdLock.Lock()
|
||||
// 从token2UserId中获取所有的key
|
||||
var copyToken2UserId = make(map[string]int)
|
||||
for s, i := range token2UserId {
|
||||
copyToken2UserId[s] = i
|
||||
}
|
||||
token2UserId = make(map[string]int)
|
||||
token2UserIdLock.Unlock()
|
||||
|
||||
for key := range copyToken2UserId {
|
||||
token, err := GetTokenByKey(key)
|
||||
if err != nil {
|
||||
// 如果数据库中不存在,则删除缓存
|
||||
common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
|
||||
//delete redis
|
||||
err := common.RedisDel(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
|
||||
}
|
||||
} else {
|
||||
// 如果数据库中存在,先检查redis
|
||||
_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
// 如果redis中不存在,则跳过
|
||||
continue
|
||||
}
|
||||
err = cacheSetToken(token)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetUserGroup(id int) (group string, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUserGroup(id)
|
||||
}
|
||||
group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
|
||||
if err != nil {
|
||||
group, err = GetUserGroup(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return group, err
|
||||
}
|
||||
|
||||
func CacheGetUsername(id int) (username string, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUsernameById(id)
|
||||
}
|
||||
username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
|
||||
if err != nil {
|
||||
username, err = GetUsernameById(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return username, err
|
||||
}
|
||||
|
||||
func CacheGetUserQuota(id int) (quota int, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUserQuota(id)
|
||||
}
|
||||
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||
if err != nil {
|
||||
quota, err = GetUserQuota(id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user quota error: " + err.Error())
|
||||
}
|
||||
return quota, err
|
||||
}
|
||||
quota, err = strconv.Atoi(quotaString)
|
||||
return quota, err
|
||||
}
|
||||
|
||||
func CacheUpdateUserQuota(id int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
quota, err := GetUserQuota(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cacheSetUserQuota(id, quota)
|
||||
}
|
||||
|
||||
func cacheSetUserQuota(id int, quota int) error {
|
||||
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
return err
|
||||
}
|
||||
|
||||
func CacheDecreaseUserQuota(id int, quota int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
|
||||
return err
|
||||
}
|
||||
|
||||
func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsUserEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
|
||||
userEnabled, err := IsUserEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if userEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
return userEnabled, err
|
||||
}
|
||||
//func CacheGetUserGroup(id int) (group string, err error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return GetUserGroup(id)
|
||||
// }
|
||||
// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
|
||||
// if err != nil {
|
||||
// group, err = GetUserGroup(id)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
|
||||
// if err != nil {
|
||||
// common.SysError("Redis set user group error: " + err.Error())
|
||||
// }
|
||||
// }
|
||||
// return group, err
|
||||
//}
|
||||
//
|
||||
//func CacheGetUsername(id int) (username string, err error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return GetUsernameById(id)
|
||||
// }
|
||||
// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
|
||||
// if err != nil {
|
||||
// username, err = GetUsernameById(id)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
|
||||
// if err != nil {
|
||||
// common.SysError("Redis set user group error: " + err.Error())
|
||||
// }
|
||||
// }
|
||||
// return username, err
|
||||
//}
|
||||
//
|
||||
//func CacheGetUserQuota(id int) (quota int, err error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return GetUserQuota(id)
|
||||
// }
|
||||
// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||
// if err != nil {
|
||||
// quota, err = GetUserQuota(id)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// return quota, nil
|
||||
// }
|
||||
// quota, err = strconv.Atoi(quotaString)
|
||||
// return quota, nil
|
||||
//}
|
||||
//
|
||||
//func CacheUpdateUserQuota(id int) error {
|
||||
// if !common.RedisEnabled {
|
||||
// return nil
|
||||
// }
|
||||
// quota, err := GetUserQuota(id)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return cacheSetUserQuota(id, quota)
|
||||
//}
|
||||
//
|
||||
//func cacheSetUserQuota(id int, quota int) error {
|
||||
// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//func CacheDecreaseUserQuota(id int, quota int) error {
|
||||
// if !common.RedisEnabled {
|
||||
// return nil
|
||||
// }
|
||||
// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return IsUserEnabled(userId)
|
||||
// }
|
||||
// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||
// if err == nil {
|
||||
// return enabled == "1", nil
|
||||
// }
|
||||
//
|
||||
// userEnabled, err := IsUserEnabled(userId)
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
// enabled = "0"
|
||||
// if userEnabled {
|
||||
// enabled = "1"
|
||||
// }
|
||||
// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
|
||||
// if err != nil {
|
||||
// common.SysError("Redis set user enabled error: " + err.Error())
|
||||
// }
|
||||
// return userEnabled, err
|
||||
//}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
@@ -344,12 +250,12 @@ func CacheGetChannel(id int) (*Channel, error) {
|
||||
}
|
||||
|
||||
func CacheUpdateChannelStatus(id int, status int) {
|
||||
if (!common.MemoryCacheEnabled) {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
14
model/log.go
14
model/log.go
@@ -12,16 +12,6 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
|
||||
func init() {
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
} else {
|
||||
groupCol = "`group`"
|
||||
}
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
@@ -81,7 +71,7 @@ func RecordLog(userId int, logType int, content string) {
|
||||
if logType == LogTypeConsume && !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username, _ := CacheGetUsername(userId)
|
||||
username, _ := GetUsernameById(userId, false)
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
@@ -102,7 +92,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username, _ := CacheGetUsername(userId)
|
||||
username, _ := GetUsernameById(userId, false)
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
|
||||
@@ -13,6 +13,20 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
var keyCol string
|
||||
|
||||
func init() {
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
keyCol = `"key"`
|
||||
|
||||
} else {
|
||||
groupCol = "`group`"
|
||||
keyCol = "`key`"
|
||||
}
|
||||
}
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
var LOG_DB *gorm.DB
|
||||
|
||||
145
model/token.go
145
model/token.go
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
relaycommon "one-api/relay/common"
|
||||
@@ -30,6 +31,10 @@ type Token struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (token *Token) Clean() {
|
||||
token.Key = ""
|
||||
}
|
||||
|
||||
func (token *Token) GetIpLimitsMap() map[string]any {
|
||||
// delete empty spaces
|
||||
//split with \n
|
||||
@@ -71,7 +76,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
if key == "" {
|
||||
return nil, errors.New("未提供令牌")
|
||||
}
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
token, err = GetTokenByKey(key, false)
|
||||
if err == nil {
|
||||
if token.Status == common.TokenStatusExhausted {
|
||||
keyPrefix := key[:3]
|
||||
@@ -128,22 +133,38 @@ func GetTokenById(id int) (*Token, error) {
|
||||
token := Token{Id: id}
|
||||
var err error = nil
|
||||
err = DB.First(&token, "id = ?", id).Error
|
||||
if err != nil {
|
||||
if common.RedisEnabled {
|
||||
go cacheSetToken(&token)
|
||||
}
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
if err := cacheSetToken(token); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
return &token, err
|
||||
}
|
||||
|
||||
func GetTokenByKey(key string) (*Token, error) {
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) && token != nil {
|
||||
gopool.Go(func() {
|
||||
if err := cacheSetToken(*token); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
// Try Redis first
|
||||
token, err := cacheGetTokenByKey(key)
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
var token Token
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return &token, err
|
||||
fromDB = true
|
||||
err = DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return token, err
|
||||
}
|
||||
|
||||
func (token *Token) Insert() error {
|
||||
@@ -153,20 +174,48 @@ func (token *Token) Insert() error {
|
||||
}
|
||||
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (token *Token) Update() error {
|
||||
var err error
|
||||
func (token *Token) Update() (err error) {
|
||||
defer func() {
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
err := cacheSetToken(*token)
|
||||
if err != nil {
|
||||
common.SysError("failed to update token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
|
||||
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (token *Token) SelectUpdate() error {
|
||||
func (token *Token) SelectUpdate() (err error) {
|
||||
defer func() {
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
err := cacheSetToken(*token)
|
||||
if err != nil {
|
||||
common.SysError("failed to update token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
// This can update zero values
|
||||
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
|
||||
}
|
||||
|
||||
func (token *Token) Delete() error {
|
||||
var err error
|
||||
func (token *Token) Delete() (err error) {
|
||||
defer func() {
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
err := cacheDeleteToken(token.Key)
|
||||
if err != nil {
|
||||
common.SysError("failed to delete token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
err = DB.Delete(token).Error
|
||||
return err
|
||||
}
|
||||
@@ -214,10 +263,18 @@ func DeleteTokenById(id int, userId int) (err error) {
|
||||
return token.Delete()
|
||||
}
|
||||
|
||||
func IncreaseTokenQuota(id int, quota int) (err error) {
|
||||
func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
gopool.Go(func() {
|
||||
err := cacheIncrTokenQuota(key, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to increase token quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||
return nil
|
||||
@@ -236,10 +293,18 @@ func increaseTokenQuota(id int, quota int) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func DecreaseTokenQuota(id int, quota int) (err error) {
|
||||
func DecreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
gopool.Go(func() {
|
||||
err := cacheDecrTokenQuota(key, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to decrease token quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||
return nil
|
||||
@@ -258,37 +323,31 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||
if quota < 0 {
|
||||
return 0, errors.New("quota 不能为负数!")
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if !relayInfo.IsPlayground {
|
||||
token, err := GetTokenById(relayInfo.TokenId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
||||
return 0, errors.New("令牌额度不足")
|
||||
}
|
||||
if relayInfo.IsPlayground {
|
||||
return nil
|
||||
}
|
||||
userQuota, err = GetUserQuota(relayInfo.UserId)
|
||||
//if relayInfo.TokenUnlimited {
|
||||
// return nil
|
||||
//}
|
||||
token, err := GetTokenById(relayInfo.TokenId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
if userQuota < quota {
|
||||
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
||||
return errors.New("令牌额度不足")
|
||||
}
|
||||
if !relayInfo.IsPlayground {
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
||||
return userQuota - quota, err
|
||||
return nil
|
||||
}
|
||||
|
||||
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
||||
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
||||
|
||||
if quota > 0 {
|
||||
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
||||
@@ -301,9 +360,9 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
|
||||
|
||||
if !relayInfo.IsPlayground {
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||
} else {
|
||||
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
|
||||
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
64
model/token_cache.go
Normal file
64
model/token_cache.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"time"
|
||||
)
|
||||
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheDeleteToken(key string) error {
|
||||
key = common.GenerateHMAC(key)
|
||||
err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheIncrTokenQuota(key string, increment int64) error {
|
||||
key = common.GenerateHMAC(key)
|
||||
err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheDecrTokenQuota(key string, decrement int64) error {
|
||||
return cacheIncrTokenQuota(key, -decrement)
|
||||
}
|
||||
|
||||
func cacheSetTokenField(key string, field string, value string) error {
|
||||
key = common.GenerateHMAC(key)
|
||||
err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CacheGetTokenByKey 从缓存中获取 token,如果缓存中不存在,则从数据库中获取
|
||||
func cacheGetTokenByKey(key string) (*Token, error) {
|
||||
hmacKey := common.GenerateHMAC(key)
|
||||
if !common.RedisEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
var token Token
|
||||
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.Key = key
|
||||
return &token, nil
|
||||
}
|
||||
191
model/user.go
191
model/user.go
@@ -6,7 +6,8 @@ import (
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -89,11 +90,6 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
var users []*User
|
||||
var err error
|
||||
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
|
||||
// 尝试将关键字转换为整数ID
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
@@ -107,7 +103,7 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
return users, err
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
err = nil
|
||||
|
||||
query := DB.Unscoped().Omit("password")
|
||||
@@ -251,14 +247,12 @@ func (user *User) Update(updatePassword bool) error {
|
||||
}
|
||||
newUser := *user
|
||||
DB.First(&user, user.Id)
|
||||
err = DB.Model(user).Updates(newUser).Error
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
}
|
||||
if err = DB.Model(user).Updates(newUser).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
|
||||
// 更新缓存
|
||||
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
||||
}
|
||||
|
||||
func (user *User) Edit(updatePassword bool) error {
|
||||
@@ -269,6 +263,7 @@ func (user *User) Edit(updatePassword bool) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newUser := *user
|
||||
updates := map[string]interface{}{
|
||||
"username": newUser.Username,
|
||||
@@ -279,23 +274,26 @@ func (user *User) Edit(updatePassword bool) error {
|
||||
if updatePassword {
|
||||
updates["password"] = newUser.Password
|
||||
}
|
||||
|
||||
DB.First(&user, user.Id)
|
||||
err = DB.Model(user).Updates(updates).Error
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
}
|
||||
if err = DB.Model(user).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
|
||||
// 更新缓存
|
||||
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
||||
}
|
||||
|
||||
func (user *User) Delete() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
err := DB.Delete(user).Error
|
||||
return err
|
||||
if err := DB.Delete(user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除缓存
|
||||
return invalidateUserCache(user.Id)
|
||||
}
|
||||
|
||||
func (user *User) HardDelete() error {
|
||||
@@ -409,15 +407,33 @@ func IsAdmin(userId int) bool {
|
||||
return user.Role >= common.RoleAdminUser
|
||||
}
|
||||
|
||||
func IsUserEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false, errors.New("user id is empty")
|
||||
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
|
||||
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserStatusCache(id, status); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
// Try Redis first
|
||||
status, err := getUserStatusCache(id)
|
||||
if err == nil {
|
||||
return status == common.UserStatusEnabled, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
||||
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
@@ -433,14 +449,33 @@ func ValidateAccessToken(token string) (user *User) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserQuota(id int) (quota int, err error) {
|
||||
// GetUserQuota gets quota from Redis first, falls back to DB if needed
|
||||
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserQuotaCache(id, quota); err != nil {
|
||||
common.SysError("failed to update user quota cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
quota, err := getUserQuotaCache(id)
|
||||
if err == nil {
|
||||
return quota, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
//common.SysError("failed to get user quota from cache: " + err.Error())
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||
if err != nil {
|
||||
if common.RedisEnabled {
|
||||
go cacheSetUserQuota(id, quota)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return quota, err
|
||||
|
||||
return quota, nil
|
||||
}
|
||||
|
||||
func GetUserUsedQuota(id int) (quota int, err error) {
|
||||
@@ -453,20 +488,44 @@ func GetUserEmail(id int) (email string, err error) {
|
||||
return email, err
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
// GetUserGroup gets group from Redis first, falls back to DB if needed
|
||||
func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserGroupCache(id, group); err != nil {
|
||||
common.SysError("failed to update user group cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
group, err := getUserGroupCache(id)
|
||||
if err == nil {
|
||||
return group, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
return group, err
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func IncreaseUserQuota(id int, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
gopool.Go(func() {
|
||||
err := cacheIncrUserQuota(id, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to increase user quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||
return nil
|
||||
@@ -476,6 +535,9 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
||||
|
||||
func increaseUserQuota(id int, quota int) (err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -483,6 +545,12 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
gopool.Go(func() {
|
||||
err := cacheDecrUserQuota(id, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to decrease user quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
||||
return nil
|
||||
@@ -492,9 +560,23 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
||||
|
||||
func decreaseUserQuota(id int, quota int) (err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func DeltaUpdateUserQuota(id int, delta int) (err error) {
|
||||
if delta == 0 {
|
||||
return nil
|
||||
}
|
||||
if delta > 0 {
|
||||
return IncreaseUserQuota(id, delta)
|
||||
} else {
|
||||
return DecreaseUserQuota(id, -delta)
|
||||
}
|
||||
}
|
||||
|
||||
func GetRootUserEmail() (email string) {
|
||||
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
||||
return email
|
||||
@@ -518,7 +600,13 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user used quota and request count: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
//// 更新缓存
|
||||
//if err := invalidateUserCache(id); err != nil {
|
||||
// common.SysError("failed to invalidate user cache: " + err.Error())
|
||||
//}
|
||||
}
|
||||
|
||||
func updateUserUsedQuota(id int, quota int) {
|
||||
@@ -539,9 +627,32 @@ func updateUserRequestCount(id int, count int) {
|
||||
}
|
||||
}
|
||||
|
||||
func GetUsernameById(id int) (username string, err error) {
|
||||
// GetUsernameById gets username from Redis first, falls back to DB if needed
|
||||
func GetUsernameById(id int, fromDB bool) (username string, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserNameCache(id, username); err != nil {
|
||||
common.SysError("failed to update user name cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
username, err := getUserNameCache(id)
|
||||
if err == nil {
|
||||
return username, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
|
||||
return username, err
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
|
||||
|
||||
206
model/user_cache.go
Normal file
206
model/user_cache.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Change UserCache struct to userCache
|
||||
type userCache struct {
|
||||
Id int `json:"id"`
|
||||
Group string `json:"group"`
|
||||
Quota int `json:"quota"`
|
||||
Status int `json:"status"`
|
||||
Role int `json:"role"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// Rename all exported functions to private ones
|
||||
// invalidateUserCache clears all user related cache
|
||||
func invalidateUserCache(userId int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := []string{
|
||||
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
||||
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
||||
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
||||
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if err := common.RedisDel(key); err != nil {
|
||||
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateUserGroupCache updates user group cache
|
||||
func updateUserGroupCache(userId int, group string) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
||||
group,
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserQuotaCache updates user quota cache
|
||||
func updateUserQuotaCache(userId int, quota int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
||||
fmt.Sprintf("%d", quota),
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserStatusCache updates user status cache
|
||||
func updateUserStatusCache(userId int, userEnabled bool) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
enabled := "0"
|
||||
if userEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
||||
enabled,
|
||||
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserNameCache updates username cache
|
||||
func updateUserNameCache(userId int, username string) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
|
||||
username,
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserCache updates all user cache fields
|
||||
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := updateUserGroupCache(userId, userGroup); err != nil {
|
||||
return fmt.Errorf("update group cache: %w", err)
|
||||
}
|
||||
|
||||
if err := updateUserQuotaCache(userId, quota); err != nil {
|
||||
return fmt.Errorf("update quota cache: %w", err)
|
||||
}
|
||||
|
||||
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
|
||||
return fmt.Errorf("update status cache: %w", err)
|
||||
}
|
||||
|
||||
if err := updateUserNameCache(userId, username); err != nil {
|
||||
return fmt.Errorf("update username cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUserGroupCache gets user group from cache
|
||||
func getUserGroupCache(userId int) (string, error) {
|
||||
if !common.RedisEnabled {
|
||||
return "", nil
|
||||
}
|
||||
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
|
||||
}
|
||||
|
||||
// getUserQuotaCache gets user quota from cache
|
||||
func getUserQuotaCache(userId int) (int, error) {
|
||||
if !common.RedisEnabled {
|
||||
return 0, nil
|
||||
}
|
||||
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.Atoi(quotaStr)
|
||||
}
|
||||
|
||||
// getUserStatusCache gets user status from cache
|
||||
func getUserStatusCache(userId int) (int, error) {
|
||||
if !common.RedisEnabled {
|
||||
return 0, nil
|
||||
}
|
||||
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.Atoi(statusStr)
|
||||
}
|
||||
|
||||
// getUserNameCache gets username from cache
|
||||
func getUserNameCache(userId int) (string, error) {
|
||||
if !common.RedisEnabled {
|
||||
return "", nil
|
||||
}
|
||||
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
|
||||
}
|
||||
|
||||
// getUserCache gets complete user cache
|
||||
func getUserCache(userId int) (*userCache, error) {
|
||||
if !common.RedisEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
group, err := getUserGroupCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group cache: %w", err)
|
||||
}
|
||||
|
||||
quota, err := getUserQuotaCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get quota cache: %w", err)
|
||||
}
|
||||
|
||||
status, err := getUserStatusCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get status cache: %w", err)
|
||||
}
|
||||
|
||||
username, err := getUserNameCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get username cache: %w", err)
|
||||
}
|
||||
|
||||
return &userCache{
|
||||
Id: userId,
|
||||
Group: group,
|
||||
Quota: quota,
|
||||
Status: status,
|
||||
Username: username,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add atomic quota operations
|
||||
func cacheIncrUserQuota(userId int, delta int64) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
|
||||
return common.RedisIncr(key, delta)
|
||||
}
|
||||
|
||||
func cacheDecrUserQuota(userId int, delta int64) error {
|
||||
return cacheIncrUserQuota(userId, -delta)
|
||||
}
|
||||
@@ -88,3 +88,7 @@ func RecordExist(err error) (bool, error) {
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func shouldUpdateRedis(fromDB bool, err error) bool {
|
||||
return common.RedisEnabled && fromDB && err == nil
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -203,13 +204,13 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: "image/" + format,
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
@@ -279,57 +280,97 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
|
||||
return v
|
||||
}
|
||||
|
||||
// func (g *GeminiChatResponse) GetResponseText() string {
|
||||
// if g == nil {
|
||||
// return ""
|
||||
// }
|
||||
// if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
// return g.Candidates[0].Content.Parts[0].Text
|
||||
// }
|
||||
// return ""
|
||||
// }
|
||||
func unescapeString(s string) (string, error) {
|
||||
var result []rune
|
||||
escaped := false
|
||||
i := 0
|
||||
|
||||
for i < len(s) {
|
||||
r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
|
||||
if r == utf8.RuneError {
|
||||
return "", fmt.Errorf("invalid UTF-8 encoding")
|
||||
}
|
||||
|
||||
if escaped {
|
||||
// 如果是转义符后的字符,检查其类型
|
||||
switch r {
|
||||
case '"':
|
||||
result = append(result, '"')
|
||||
case '\\':
|
||||
result = append(result, '\\')
|
||||
case '/':
|
||||
result = append(result, '/')
|
||||
case 'b':
|
||||
result = append(result, '\b')
|
||||
case 'f':
|
||||
result = append(result, '\f')
|
||||
case 'n':
|
||||
result = append(result, '\n')
|
||||
case 'r':
|
||||
result = append(result, '\r')
|
||||
case 't':
|
||||
result = append(result, '\t')
|
||||
case '\'':
|
||||
result = append(result, '\'')
|
||||
default:
|
||||
// 如果遇到一个非法的转义字符,直接按原样输出
|
||||
result = append(result, '\\', r)
|
||||
}
|
||||
escaped = false
|
||||
} else {
|
||||
if r == '\\' {
|
||||
escaped = true // 记录反斜杠作为转义符
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
i += size // 移动到下一个字符
|
||||
}
|
||||
|
||||
return string(result), nil
|
||||
}
|
||||
func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
for k, val := range v {
|
||||
v[k] = unescapeMapOrSlice(val)
|
||||
}
|
||||
case []interface{}:
|
||||
for i, val := range v {
|
||||
v[i] = unescapeMapOrSlice(val)
|
||||
}
|
||||
case string:
|
||||
if unescaped, err := unescapeString(v); err != nil {
|
||||
return v
|
||||
} else {
|
||||
return unescaped
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
|
||||
} else {
|
||||
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
//common.SysError("getToolCall failed: " + err.Error())
|
||||
return nil
|
||||
}
|
||||
return &dto.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
Type: "function",
|
||||
Function: dto.FunctionCall{
|
||||
// 不好评价,得去转义一下反斜杠,Gemini 的特性好像是,Google 返回的时候本身就会转义“\”
|
||||
Arguments: strings.ReplaceAll(string(argsBytes), "\\\\", "\\"),
|
||||
Arguments: string(argsBytes),
|
||||
Name: item.FunctionCall.FunctionName,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
|
||||
// var toolCalls []dto.ToolCall
|
||||
|
||||
// item := candidate.Content.Parts[index]
|
||||
// if item.FunctionCall == nil {
|
||||
// return toolCalls
|
||||
// }
|
||||
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
// if err != nil {
|
||||
// //common.SysError("getToolCalls failed: " + err.Error())
|
||||
// return toolCalls
|
||||
// }
|
||||
// toolCall := dto.ToolCall{
|
||||
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
// Type: "function",
|
||||
// Function: dto.FunctionCall{
|
||||
// Arguments: string(argsBytes),
|
||||
// Name: item.FunctionCall.FunctionName,
|
||||
// },
|
||||
// }
|
||||
// toolCalls = append(toolCalls, toolCall)
|
||||
// return toolCalls
|
||||
// }
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
|
||||
@@ -77,27 +77,13 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
|
||||
@@ -100,7 +100,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
}
|
||||
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
|
||||
sizeRatio := 1.0
|
||||
// Size
|
||||
|
||||
@@ -170,7 +170,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -194,11 +194,11 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
}
|
||||
defer func(ctx context.Context) {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
//err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
@@ -476,7 +476,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -500,14 +500,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
|
||||
|
||||
@@ -262,7 +262,7 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
|
||||
|
||||
// 预扣费并返回用户剩余配额
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -272,10 +272,6 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !relayInfo.TokenUnlimited {
|
||||
@@ -293,11 +289,16 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
||||
}
|
||||
}
|
||||
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, userQuota, nil
|
||||
}
|
||||
@@ -307,7 +308,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
|
||||
go func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
|
||||
err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
@@ -365,15 +366,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
//}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
// 预扣
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -113,14 +113,10 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
// release quota
|
||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||
|
||||
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
|
||||
err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
|
||||
|
||||
@@ -5,11 +5,12 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
@@ -31,6 +32,31 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
func DecodeBase64FileData(base64String string) (string, string, error) {
|
||||
var mimeType string
|
||||
var idx int
|
||||
idx = strings.Index(base64String, ",")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = base64String[:idx]
|
||||
base64String = base64String[idx+1:]
|
||||
idx = strings.Index(mimeType, ";")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = mimeType[:idx]
|
||||
idx = strings.Index(mimeType, ":")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = mimeType[idx+1:]
|
||||
return mimeType, base64String, nil
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
resp, err := DoDownloadRequest(url)
|
||||
|
||||
@@ -18,12 +18,12 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
if relayInfo.UsePrice {
|
||||
return nil
|
||||
}
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"))
|
||||
token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -58,15 +58,11 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
|
||||
}
|
||||
|
||||
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
|
||||
err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -120,7 +116,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
//}
|
||||
//quotaDelta := quota - preConsumedQuota
|
||||
//if quotaDelta != 0 {
|
||||
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
// err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
// if err != nil {
|
||||
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
// }
|
||||
@@ -190,15 +186,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
} else {
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -146,8 +146,9 @@ const PersonalSetting = () => {
|
||||
let res = await API.get(`/api/user/models`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
setModels(data);
|
||||
console.log(data);
|
||||
if (data != null) {
|
||||
setModels(data);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import i18next from 'i18next';
|
||||
import { Modal, Tag } from '@douyinfe/semi-ui';
|
||||
import { Modal, Tag, Typography } from '@douyinfe/semi-ui';
|
||||
import { copy, showSuccess } from './utils.js';
|
||||
|
||||
export function renderText(text, limit) {
|
||||
@@ -55,6 +55,81 @@ export function renderGroup(group) {
|
||||
);
|
||||
}
|
||||
|
||||
export function renderRatio(ratio) {
|
||||
let color = 'green';
|
||||
if (ratio > 5) {
|
||||
color = 'red';
|
||||
} else if (ratio > 3) {
|
||||
color = 'orange';
|
||||
} else if (ratio > 1) {
|
||||
color = 'blue';
|
||||
}
|
||||
return <Tag color={color}>{ratio}x {i18next.t('倍率')}</Tag>;
|
||||
}
|
||||
|
||||
export const renderGroupOption = (item) => {
|
||||
const {
|
||||
disabled,
|
||||
selected,
|
||||
label,
|
||||
value,
|
||||
focused,
|
||||
className,
|
||||
style,
|
||||
onMouseEnter,
|
||||
onClick,
|
||||
empty,
|
||||
emptyContent,
|
||||
...rest
|
||||
} = item;
|
||||
|
||||
const baseStyle = {
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
padding: '8px 16px',
|
||||
cursor: disabled ? 'not-allowed' : 'pointer',
|
||||
backgroundColor: focused ? 'var(--semi-color-fill-0)' : 'transparent',
|
||||
opacity: disabled ? 0.5 : 1,
|
||||
...(selected && {
|
||||
backgroundColor: 'var(--semi-color-primary-light-default)',
|
||||
}),
|
||||
'&:hover': {
|
||||
backgroundColor: !disabled && 'var(--semi-color-fill-1)'
|
||||
}
|
||||
};
|
||||
|
||||
const handleClick = () => {
|
||||
if (!disabled && onClick) {
|
||||
onClick();
|
||||
}
|
||||
};
|
||||
|
||||
const handleMouseEnter = (e) => {
|
||||
if (!disabled && onMouseEnter) {
|
||||
onMouseEnter(e);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
style={baseStyle}
|
||||
onClick={handleClick}
|
||||
onMouseEnter={handleMouseEnter}
|
||||
>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
|
||||
<Typography.Text strong type={disabled ? 'tertiary' : undefined}>
|
||||
{value}
|
||||
</Typography.Text>
|
||||
<Typography.Text type="secondary" size="small">
|
||||
{label}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
{item.ratio && renderRatio(item.ratio)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export function renderNumber(num) {
|
||||
if (num >= 1000000000) {
|
||||
return (num / 1000000000).toFixed(1) + 'B';
|
||||
@@ -352,7 +427,7 @@ export const modelColorMap = {
|
||||
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
|
||||
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
|
||||
'gpt-3.5-turbo-16k': 'rgb(149,252,206)', // 淡橙色
|
||||
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃<EFBFBD><EFBFBD><EFBFBD>
|
||||
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃
|
||||
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
|
||||
'gpt-4': 'rgb(135,206,235)', // 天蓝色
|
||||
// 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
|
||||
@@ -375,7 +450,7 @@ export const modelColorMap = {
|
||||
'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红
|
||||
'text-embedding-v1': 'rgb(255,174,185)', // 浅粉红色(略有区别)
|
||||
'text-moderation-latest': 'rgb(255,130,171)', // 强粉色
|
||||
'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(<EFBFBD><EFBFBD><EFBFBD>Babbage相同,表示同一类功能)
|
||||
'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(与Babbage相同,表示同一类功能)
|
||||
'tts-1': 'rgb(255,140,0)', // 深橙色
|
||||
'tts-1-1106': 'rgb(255,165,0)', // 橙色
|
||||
'tts-1-hd': 'rgb(255,215,0)', // 金色
|
||||
|
||||
@@ -49,8 +49,18 @@ export async function copy(text) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} catch (e) {
|
||||
okay = false;
|
||||
console.error(e);
|
||||
try {
|
||||
// 构建input 执行 复制命令
|
||||
var _input = window.document.createElement("input");
|
||||
_input.value = text;
|
||||
window.document.body.appendChild(_input);
|
||||
_input.select();
|
||||
window.document.execCommand("Copy");
|
||||
window.document.body.removeChild(_input);
|
||||
} catch (e) {
|
||||
okay = false;
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return okay;
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ import React, { useCallback, useContext, useEffect, useState } from 'react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../../context/User/index.js';
|
||||
import { API, getUserIdFromLocalStorage, showError } from '../../helpers/index.js';
|
||||
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography, Button } from '@douyinfe/semi-ui';
|
||||
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography, Button, Highlight } from '@douyinfe/semi-ui';
|
||||
import { SSE } from 'sse';
|
||||
import { IconSetting } from '@douyinfe/semi-icons';
|
||||
import { StyleContext } from '../../context/Style/index.js';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { renderGroupOption } from '../../helpers/render.js';
|
||||
|
||||
const roleInfo = {
|
||||
user: {
|
||||
@@ -97,15 +98,17 @@ const Playground = () => {
|
||||
let res = await API.get(`/api/user/self/groups`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let localGroupOptions = Object.keys(data).map((group) => ({
|
||||
label: data[group],
|
||||
let localGroupOptions = Object.entries(data).map(([group, info]) => ({
|
||||
label: info.desc,
|
||||
value: group,
|
||||
ratio: info.ratio
|
||||
}));
|
||||
|
||||
if (localGroupOptions.length === 0) {
|
||||
localGroupOptions = [{
|
||||
label: t('用户分组'),
|
||||
value: '',
|
||||
ratio: 1
|
||||
}];
|
||||
} else {
|
||||
const localUser = JSON.parse(localStorage.getItem('user'));
|
||||
@@ -326,12 +329,9 @@ const Playground = () => {
|
||||
}}
|
||||
value={inputs.group}
|
||||
autoComplete='new-password'
|
||||
optionList={groups.map((group) => ({
|
||||
...group,
|
||||
label: styleState.isMobile && group.label.length > 16
|
||||
? group.label.substring(0, 16) + '...'
|
||||
: group.label,
|
||||
}))}
|
||||
optionList={groups}
|
||||
renderOptionItem={renderGroupOption}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>{t('模型')}:</Typography.Text>
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
showSuccess,
|
||||
timestamp2string,
|
||||
} from '../../helpers';
|
||||
import { renderQuotaWithPrompt } from '../../helpers/render';
|
||||
import { renderGroupOption, renderQuotaWithPrompt } from '../../helpers/render';
|
||||
import {
|
||||
AutoComplete,
|
||||
Banner,
|
||||
@@ -97,9 +97,10 @@ const EditToken = (props) => {
|
||||
let res = await API.get(`/api/user/self/groups`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let localGroupOptions = Object.keys(data).map((group) => ({
|
||||
label: data[group],
|
||||
let localGroupOptions = Object.entries(data).map(([group, info]) => ({
|
||||
label: info.desc,
|
||||
value: group,
|
||||
ratio: info.ratio
|
||||
}));
|
||||
setGroups(localGroupOptions);
|
||||
} else {
|
||||
@@ -449,6 +450,8 @@ const EditToken = (props) => {
|
||||
onChange={(value) => {
|
||||
handleInputChange('group', value);
|
||||
}}
|
||||
position={'topLeft'}
|
||||
renderOptionItem={renderGroupOption}
|
||||
value={inputs.group}
|
||||
autoComplete='new-password'
|
||||
optionList={groups}
|
||||
|
||||
Reference in New Issue
Block a user