mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-04 22:00:30 +00:00
Compare commits
90 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4f17543cb | ||
|
|
1eb706de7a | ||
|
|
d13d81baba | ||
|
|
65af1a4d10 | ||
|
|
1ae0a3fb83 | ||
|
|
fe2e8f1a42 | ||
|
|
a5f7f8af29 | ||
|
|
2f01a2125f | ||
|
|
e4f9787c16 | ||
|
|
bb5e032dd2 | ||
|
|
304c92ceab | ||
|
|
05874dcca5 | ||
|
|
ca8b7ed1c3 | ||
|
|
ed435e5c8f | ||
|
|
a1b864bc5e | ||
|
|
2a15dfccea | ||
|
|
9e5a7ed541 | ||
|
|
65d1cde8fb | ||
|
|
8f4a2df5ee | ||
|
|
2b38e8ed8d | ||
|
|
d75ecfc63e | ||
|
|
91b777f33f | ||
|
|
72dc54309c | ||
|
|
458dd1bd9d | ||
|
|
38cff317a0 | ||
|
|
c8614f9890 | ||
|
|
10d896aa7f | ||
|
|
118eb362c4 | ||
|
|
52c023a1dd | ||
|
|
1cef91a741 | ||
|
|
77861e6440 | ||
|
|
5f082d72bb | ||
|
|
0fd0e5d309 | ||
|
|
d2297d2723 | ||
|
|
62ae46b552 | ||
|
|
0b1354ed51 | ||
|
|
132c71390c | ||
|
|
bb3deb7b93 | ||
|
|
f92d96e298 | ||
|
|
c86762b656 | ||
|
|
3409d7a6b6 | ||
|
|
bfba4866a5 | ||
|
|
4fc1fe318e | ||
|
|
b3576f24ef | ||
|
|
ed4d26fc9e | ||
|
|
ba56e2e8ca | ||
|
|
7c20e6d047 | ||
|
|
72d6898eb5 | ||
|
|
f2c9388139 | ||
|
|
aaf5cecefd | ||
|
|
fe2165ace6 | ||
|
|
3003d12a20 | ||
|
|
a8a2195ab1 | ||
|
|
d40e6ec25d | ||
|
|
8129aa76f9 | ||
|
|
fb8595da18 | ||
|
|
93cda60d44 | ||
|
|
2ec5eafbce | ||
|
|
be0c240e97 | ||
|
|
7180e6f114 | ||
|
|
61495a460a | ||
|
|
cf3287a10a | ||
|
|
f3f1817aea | ||
|
|
a4795737fe | ||
|
|
eec8f523ce | ||
|
|
58fac129d6 | ||
|
|
241c9389ef | ||
|
|
1d0ef89ce9 | ||
|
|
cce2990db6 | ||
|
|
a7e1d17c3e | ||
|
|
c4e256e69b | ||
|
|
87a5e40daf | ||
|
|
0c326556aa | ||
|
|
794f6a6e34 | ||
|
|
656e809202 | ||
|
|
53ab2aaee4 | ||
|
|
a02bc3342f | ||
|
|
f54d0cb3b0 | ||
|
|
a5c48c2772 | ||
|
|
cffaf0d636 | ||
|
|
865b98a454 | ||
|
|
5bdbf3a673 | ||
|
|
43a7b59b68 | ||
|
|
eac3463401 | ||
|
|
1fa478af20 | ||
|
|
3b58b4989d | ||
|
|
0b1ba2eeb9 | ||
|
|
35277f2b4a | ||
|
|
03256dbdad | ||
|
|
f9a7f6085e |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
||||
.github
|
||||
.git
|
||||
*.md
|
||||
.vscode
|
||||
.gitignore
|
||||
Makefile
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -8,3 +8,5 @@ build
|
||||
logs
|
||||
web/dist
|
||||
.env
|
||||
one-api
|
||||
.DS_Store
|
||||
@@ -81,6 +81,9 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
||||
- `UPDATE_TASK`: Update async tasks (Midjourney, Suno), default `true`
|
||||
- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
|
||||
- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
|
||||
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
|
||||
- `CRYPTO_SECRET`: Encryption key for encrypting database content
|
||||
|
||||
## Deployment
|
||||
> [!TIP]
|
||||
@@ -91,6 +94,10 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
|
||||
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
|
||||
> ```
|
||||
|
||||
### Multi-Server Deployment
|
||||
- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
|
||||
- If using a public Redis, must set `CRYPTO_SECRET` environment variable, otherwise Redis content will not be able to be obtained in multi-server deployment.
|
||||
|
||||
### Requirements
|
||||
- Local database (default): SQLite (Docker deployment must mount `/data` directory)
|
||||
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6
|
||||
|
||||
@@ -87,6 +87,9 @@
|
||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
||||
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
|
||||
- `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`,`STRICT`,默认为 `NONE`。
|
||||
- `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
|
||||
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
|
||||
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
|
||||
## 部署
|
||||
> [!TIP]
|
||||
> 最新版Docker镜像:`calciumion/new-api:latest`
|
||||
@@ -96,6 +99,10 @@
|
||||
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
|
||||
> ```
|
||||
|
||||
### 多机部署
|
||||
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
|
||||
- 如果公用Redis,必须设置 `CRYPTO_SECRET`,否则会导致多机部署时Redis内容无法获取。
|
||||
|
||||
### 部署要求
|
||||
- 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机)
|
||||
- 远程数据库:MySQL 版本 >= 5.7.8,PgSQL 版本 >= 9.6
|
||||
|
||||
@@ -30,6 +30,7 @@ var DefaultCollapseSidebar = false // default value of collapse sidebar
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
var SessionSecret = uuid.New().String()
|
||||
var CryptoSecret = uuid.New().String()
|
||||
|
||||
var OptionMap map[string]string
|
||||
var OptionMapRWMutex sync.RWMutex
|
||||
|
||||
@@ -1,6 +1,23 @@
|
||||
package common
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func GenerateHMACWithKey(key []byte, data string) string {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func GenerateHMAC(data string) string {
|
||||
h := hmac.New(sha256.New, []byte(CryptoSecret))
|
||||
h.Write([]byte(data))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func Password2Hash(password string) (string, error) {
|
||||
passwordBytes := []byte(password)
|
||||
|
||||
@@ -22,7 +22,7 @@ func printHelp() {
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
}
|
||||
|
||||
func init() {
|
||||
func LoadEnv() {
|
||||
flag.Parse()
|
||||
|
||||
if *PrintVersion {
|
||||
@@ -45,6 +45,11 @@ func init() {
|
||||
SessionSecret = ss
|
||||
}
|
||||
}
|
||||
if os.Getenv("CRYPTO_SECRET") != "" {
|
||||
CryptoSecret = os.Getenv("CRYPTO_SECRET")
|
||||
} else {
|
||||
CryptoSecret = SessionSecret
|
||||
}
|
||||
if os.Getenv("SQLITE_PATH") != "" {
|
||||
SQLitePath = os.Getenv("SQLITE_PATH")
|
||||
}
|
||||
|
||||
@@ -46,6 +46,8 @@ var defaultModelRatio = map[string]float64{
|
||||
"gpt-4o-2024-08-06": 1.25, // $2.5 / 1M tokens
|
||||
"gpt-4o-2024-11-20": 1.25, // $2.5 / 1M tokens
|
||||
"gpt-4o-realtime-preview": 2.5,
|
||||
"o1": 7.5,
|
||||
"o1-2024-12-17": 7.5,
|
||||
"o1-preview": 7.5,
|
||||
"o1-preview-2024-09-12": 7.5,
|
||||
"o1-mini": 1.5,
|
||||
@@ -354,7 +356,7 @@ func GetCompletionRatio(name string) float64 {
|
||||
}
|
||||
return 2
|
||||
}
|
||||
if strings.HasPrefix(name, "o1-") {
|
||||
if strings.HasPrefix(name, "o1") {
|
||||
return 4
|
||||
}
|
||||
if name == "chatgpt-4o-latest" {
|
||||
|
||||
216
common/redis.go
216
common/redis.go
@@ -2,9 +2,15 @@ package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var RDB *redis.Client
|
||||
@@ -56,39 +62,167 @@ func RedisGet(key string) (string, error) {
|
||||
return RDB.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
func RedisExpire(key string, expiration time.Duration) error {
|
||||
ctx := context.Background()
|
||||
return RDB.Expire(ctx, key, expiration).Err()
|
||||
}
|
||||
|
||||
func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
||||
ctx := context.Background()
|
||||
return RDB.GetSet(ctx, key, expiration).Result()
|
||||
}
|
||||
//func RedisExpire(key string, expiration time.Duration) error {
|
||||
// ctx := context.Background()
|
||||
// return RDB.Expire(ctx, key, expiration).Err()
|
||||
//}
|
||||
//
|
||||
//func RedisGetEx(key string, expiration time.Duration) (string, error) {
|
||||
// ctx := context.Background()
|
||||
// return RDB.GetSet(ctx, key, expiration).Result()
|
||||
//}
|
||||
|
||||
func RedisDel(key string) error {
|
||||
ctx := context.Background()
|
||||
return RDB.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisDecrease(key string, value int64) error {
|
||||
func RedisHDelObj(key string) error {
|
||||
ctx := context.Background()
|
||||
return RDB.HDel(ctx, key).Err()
|
||||
}
|
||||
|
||||
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
ctx := context.Background()
|
||||
|
||||
data := make(map[string]interface{})
|
||||
|
||||
// 使用反射遍历结构体字段
|
||||
v := reflect.ValueOf(obj).Elem()
|
||||
t := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
value := v.Field(i)
|
||||
|
||||
// Skip DeletedAt field
|
||||
if field.Type.String() == "gorm.DeletedAt" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理指针类型
|
||||
if value.Kind() == reflect.Ptr {
|
||||
if value.IsNil() {
|
||||
data[field.Name] = ""
|
||||
continue
|
||||
}
|
||||
value = value.Elem()
|
||||
}
|
||||
|
||||
// 处理布尔类型
|
||||
if value.Kind() == reflect.Bool {
|
||||
data[field.Name] = strconv.FormatBool(value.Bool())
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他类型直接转换为字符串
|
||||
data[field.Name] = fmt.Sprintf("%v", value.Interface())
|
||||
}
|
||||
|
||||
txn := RDB.TxPipeline()
|
||||
txn.HSet(ctx, key, data)
|
||||
txn.Expire(ctx, key, expiration)
|
||||
|
||||
_, err := txn.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute transaction: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHGetObj(key string, obj interface{}) error {
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := RDB.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load hash from Redis: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return fmt.Errorf("key %s not found in Redis", key)
|
||||
}
|
||||
|
||||
// Handle both pointer and non-pointer values
|
||||
val := reflect.ValueOf(obj)
|
||||
if val.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
|
||||
}
|
||||
|
||||
v := val.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
fieldName := field.Name
|
||||
if value, ok := result[fieldName]; ok {
|
||||
fieldValue := v.Field(i)
|
||||
|
||||
// Handle pointer types
|
||||
if fieldValue.Kind() == reflect.Ptr {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
// Enhanced type handling for Token struct
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.String:
|
||||
fieldValue.SetString(value)
|
||||
case reflect.Int, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.SetInt(intValue)
|
||||
case reflect.Bool:
|
||||
boolValue, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.SetBool(boolValue)
|
||||
case reflect.Struct:
|
||||
// Special handling for gorm.DeletedAt
|
||||
if fieldValue.Type().String() == "gorm.DeletedAt" {
|
||||
if value != "" {
|
||||
timeValue, err := time.Parse(time.RFC3339, value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
|
||||
}
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RedisIncr Add this function to handle atomic increments
|
||||
func RedisIncr(key string, delta int64) error {
|
||||
// 检查键的剩余生存时间
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil {
|
||||
// 失败则尝试直接减少
|
||||
return RDB.DecrBy(context.Background(), key, value).Err()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
// 如果剩余生存时间大于0,则进行减少操作
|
||||
// 只有在 key 存在且有 TTL 时才需要特殊处理
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
// 开始一个Redis事务
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
// 减少余额
|
||||
decrCmd := txn.DecrBy(ctx, key, value)
|
||||
decrCmd := txn.IncrBy(ctx, key, delta)
|
||||
if err := decrCmd.Err(); err != nil {
|
||||
return err // 如果减少失败,则直接返回错误
|
||||
}
|
||||
@@ -99,8 +233,54 @@ func RedisDecrease(key string, value int64) error {
|
||||
// 执行事务
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
} else {
|
||||
_ = RedisDel(key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHIncrBy(key, field string, delta int64) error {
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
incrCmd := txn.HIncrBy(ctx, key, field, delta)
|
||||
if err := incrCmd.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RedisHSetField(key, field string, value interface{}) error {
|
||||
ttlCmd := RDB.TTL(context.Background(), key)
|
||||
ttl, err := ttlCmd.Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return fmt.Errorf("failed to get TTL: %w", err)
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
ctx := context.Background()
|
||||
txn := RDB.TxPipeline()
|
||||
|
||||
hsetCmd := txn.HSet(ctx, key, field, value)
|
||||
if err := hsetCmd.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn.Expire(ctx, key, ttl)
|
||||
|
||||
_, err = txn.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
var UserUsableGroups = map[string]string{
|
||||
"default": "默认分组",
|
||||
"vip": "vip分组",
|
||||
}
|
||||
|
||||
func UserUsableGroups2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(UserUsableGroups)
|
||||
if err != nil {
|
||||
SysError("error marshalling user groups: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
|
||||
UserUsableGroups = make(map[string]string)
|
||||
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
|
||||
}
|
||||
|
||||
func GetUserUsableGroups(userGroup string) map[string]string {
|
||||
if userGroup == "" {
|
||||
// 如果userGroup为空,返回UserUsableGroups
|
||||
return UserUsableGroups
|
||||
}
|
||||
// 如果userGroup不在UserUsableGroups中,返回UserUsableGroups + userGroup
|
||||
if _, ok := UserUsableGroups[userGroup]; !ok {
|
||||
appendUserUsableGroups := make(map[string]string)
|
||||
for k, v := range UserUsableGroups {
|
||||
appendUserUsableGroups[k] = v
|
||||
}
|
||||
appendUserUsableGroups[userGroup] = "用户分组"
|
||||
return appendUserUsableGroups
|
||||
}
|
||||
// 如果userGroup在UserUsableGroups中,返回UserUsableGroups
|
||||
return UserUsableGroups
|
||||
}
|
||||
|
||||
func GroupInUserUsableGroups(groupName string) bool {
|
||||
_, ok := UserUsableGroups[groupName]
|
||||
return ok
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"html/template"
|
||||
"log"
|
||||
"math/big"
|
||||
@@ -15,6 +14,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func OpenBrowser(url string) {
|
||||
|
||||
23
constant/cache_key.go
Normal file
23
constant/cache_key.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
UserGroupKeyFmt = "user_group:%d"
|
||||
UserQuotaKeyFmt = "user_quota:%d"
|
||||
UserEnabledKeyFmt = "user_enabled:%d"
|
||||
UserUsernameKeyFmt = "user_name:%d"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenFiledRemainQuota = "RemainQuota"
|
||||
TokenFieldGroup = "Group"
|
||||
)
|
||||
5
constant/context_key.go
Normal file
5
constant/context_key.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
ContextKeyRequestStartTime = "request_start_time"
|
||||
)
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
|
||||
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
|
||||
@@ -23,6 +25,8 @@ var GeminiModelMap = map[string]string{
|
||||
"gemini-1.0-pro": "v1",
|
||||
}
|
||||
|
||||
var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
func InitEnv() {
|
||||
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||
if modelVersionMapStr == "" {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package constant
|
||||
|
||||
var (
|
||||
FinishReasonStop = "stop"
|
||||
FinishReasonToolCalls = "tool_calls"
|
||||
FinishReasonStop = "stop"
|
||||
FinishReasonToolCalls = "tool_calls"
|
||||
FinishReasonLength = "length"
|
||||
FinishReasonFunctionCall = "function_call"
|
||||
FinishReasonContentFilter = "content_filter"
|
||||
)
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
package constant
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
var MjAccountFilterEnabled = false
|
||||
var MjModeClearEnabled = false
|
||||
var MjForwardUrlEnabled = true
|
||||
var MjActionCheckSuccessEnabled = true
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
|
||||
@@ -21,7 +21,7 @@ func GetSubscription(c *gin.Context) {
|
||||
usedQuota = token.UsedQuota
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
remainQuota, err = model.GetUserQuota(userId)
|
||||
remainQuota, err = model.GetUserQuota(userId, false)
|
||||
usedQuota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if expiredTime <= 0 {
|
||||
|
||||
@@ -141,7 +141,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
@@ -151,8 +152,8 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
|
||||
Model: "", // this will be set later
|
||||
Stream: false,
|
||||
}
|
||||
if strings.HasPrefix(model, "o1-") {
|
||||
testRequest.MaxCompletionTokens = 1
|
||||
if strings.HasPrefix(model, "o1") {
|
||||
testRequest.MaxCompletionTokens = 10
|
||||
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
|
||||
testRequest.MaxTokens = 2
|
||||
} else {
|
||||
|
||||
@@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if channel.Type != common.ChannelTypeOpenAI {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "仅支持 OpenAI 类型渠道",
|
||||
})
|
||||
return
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
result := OpenAIModelsResponse{}
|
||||
err = json.Unmarshal(body, &result)
|
||||
if err != nil {
|
||||
|
||||
var result OpenAIModelsResponse
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
}
|
||||
if !result.Success {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "上游返回错误",
|
||||
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var ids []string
|
||||
@@ -417,7 +419,8 @@ func EditTagChannels(c *gin.Context) {
|
||||
}
|
||||
|
||||
type ChannelBatch struct {
|
||||
Ids []int `json:"ids"`
|
||||
Ids []int `json:"ids"`
|
||||
Tag *string `json:"tag"`
|
||||
}
|
||||
|
||||
func DeleteChannelBatch(c *gin.Context) {
|
||||
@@ -492,3 +495,105 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func FetchModels(c *gin.Context) {
|
||||
var req struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Key string `json:"key"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"message": "Invalid request",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := req.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
url := fmt.Sprintf("%s/v1/models", baseURL)
|
||||
|
||||
request, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", "Bearer "+req.Key)
|
||||
|
||||
response, err := client.Do(request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
//check status code
|
||||
if response.StatusCode != http.StatusOK {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "Failed to fetch models",
|
||||
})
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var models []string
|
||||
for _, model := range result.Data {
|
||||
models = append(models, model.ID)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
func BatchSetChannelTag(c *gin.Context) {
|
||||
channelBatch := ChannelBatch{}
|
||||
err := c.ShouldBindJSON(&channelBatch)
|
||||
if err != nil || len(channelBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": len(channelBatch.Ids),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,13 +3,13 @@ package controller
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range common.GroupRatio {
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -20,15 +20,18 @@ func GetGroups(c *gin.Context) {
|
||||
}
|
||||
|
||||
func GetUserGroups(c *gin.Context) {
|
||||
usableGroups := make(map[string]string)
|
||||
usableGroups := make(map[string]map[string]interface{})
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.CacheGetUserGroup(userId)
|
||||
for groupName, _ := range common.GroupRatio {
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := common.GetUserUsableGroups(userGroup)
|
||||
if _, ok := userUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = userUsableGroups[groupName]
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
usableGroups[groupName] = map[string]interface{}{
|
||||
"ratio": ratio,
|
||||
"desc": desc,
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -25,7 +25,8 @@ func GetAllLogs(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel)
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -63,7 +64,8 @@ func GetUserLogs(c *gin.Context) {
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize)
|
||||
group := c.Query("group")
|
||||
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -146,7 +148,8 @@ func GetLogsStat(c *gin.Context) {
|
||||
username := c.Query("username")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
group := c.Query("group")
|
||||
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -168,7 +171,8 @@ func GetLogsSelfStat(c *gin.Context) {
|
||||
tokenName := c.Query("token_name")
|
||||
modelName := c.Query("model_name")
|
||||
channel, _ := strconv.Atoi(c.Query("channel"))
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
|
||||
group := c.Query("group")
|
||||
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
|
||||
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
|
||||
@@ -10,10 +10,10 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
@@ -231,9 +231,9 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
@@ -263,9 +263,9 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
if constant.MjForwardUrlEnabled {
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -47,9 +47,9 @@ func GetStatus(c *gin.Context) {
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": constant.ServerAddress,
|
||||
"price": constant.Price,
|
||||
"min_topup": constant.MinTopUp,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -63,9 +63,9 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
|
||||
"mj_notify_enabled": constant.MjNotifyEnabled,
|
||||
"chats": constant.Chats,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -207,7 +207,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
|
||||
@@ -166,7 +166,7 @@ func ListModels(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
userId := c.GetInt("id")
|
||||
userGroup, err := model.GetUserGroup(userId)
|
||||
userGroup, err := model.GetUserGroup(userId, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = common.CheckGroupRatio(option.Value)
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
69
controller/playground.go
Normal file
69
controller/playground.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
group := playgroundRequest.Group
|
||||
userGroup := c.GetString("group")
|
||||
|
||||
if group == "" {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
Relay(c)
|
||||
}
|
||||
@@ -4,14 +4,38 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
pricing := model.GetPricing()
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
if exists {
|
||||
user, err := model.GetUserById(userId.(int), false)
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
}
|
||||
}
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": common.GroupRatio,
|
||||
"success": true,
|
||||
"data": pricing,
|
||||
"group_ratio": groupRatio,
|
||||
"usable_group": usableGroup,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -48,58 +48,6 @@ func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErr
|
||||
return err
|
||||
}
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.Error,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
useAccessToken := c.GetBool("use_access_token")
|
||||
if useAccessToken {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
playgroundRequest := &dto.PlayGroundRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, playgroundRequest)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if playgroundRequest.Model == "" {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c.Set("original_model", playgroundRequest.Model)
|
||||
group := playgroundRequest.Group
|
||||
userGroup := c.GetString("group")
|
||||
|
||||
if group == "" {
|
||||
group = userGroup
|
||||
} else {
|
||||
if !common.GroupInUserUsableGroups(group) && group != userGroup {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
|
||||
@@ -153,7 +153,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -28,13 +28,13 @@ type AmountRequest struct {
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
|
||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
||||
return nil
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: constant.EpayId,
|
||||
Key: constant.EpayKey,
|
||||
}, constant.PayAddress)
|
||||
PartnerID: setting.EpayId,
|
||||
Key: setting.EpayKey,
|
||||
}, setting.PayAddress)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -50,12 +50,12 @@ func getPayMoney(amount float64, group string) float64 {
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * constant.Price * topupGroupRatio
|
||||
payMoney := amount * setting.Price * topupGroupRatio
|
||||
return payMoney
|
||||
}
|
||||
|
||||
func getMinTopup() int {
|
||||
minTopup := constant.MinTopUp
|
||||
minTopup := setting.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
minTopup = minTopup * int(common.QuotaPerUnit)
|
||||
}
|
||||
@@ -75,7 +75,7 @@ func RequestEpay(c *gin.Context) {
|
||||
}
|
||||
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
@@ -94,7 +94,7 @@ func RequestEpay(c *gin.Context) {
|
||||
payType = "wxpay"
|
||||
}
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
@@ -236,7 +236,7 @@ func RequestAmount(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
group, err := model.CacheGetUserGroup(id)
|
||||
group, err := model.GetUserGroup(id, true)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
||||
return
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -454,7 +455,15 @@ func GetUserModels(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
models := model.GetGroupModels(user.Group)
|
||||
groups := setting.GetUserUsableGroups(user.Group)
|
||||
var models []string
|
||||
for group := range groups {
|
||||
for _, g := range model.GetGroupModels(group) {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
8
dto/file_data.go
Normal file
8
dto/file_data.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package dto
|
||||
|
||||
type LocalFileData struct {
|
||||
MimeType string
|
||||
Base64Data string
|
||||
Url string
|
||||
Size int64
|
||||
}
|
||||
@@ -3,39 +3,48 @@ package dto
|
||||
import "encoding/json"
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
|
||||
}
|
||||
|
||||
type FormatJsonSchema struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Schema any `json:"schema,omitempty"`
|
||||
Strict any `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat any `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat any `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools []ToolCall `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Modalities any `json:"modalities,omitempty"`
|
||||
Audio any `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAITools struct {
|
||||
@@ -80,11 +89,11 @@ type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCalls any `json:"tool_calls,omitempty"`
|
||||
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type MediaMessage struct {
|
||||
type MediaContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
@@ -107,7 +116,23 @@ const (
|
||||
ContentTypeInputAudio = "input_audio"
|
||||
)
|
||||
|
||||
func (m Message) StringContent() string {
|
||||
func (m *Message) ParseToolCalls() []ToolCall {
|
||||
if m.ToolCalls == nil {
|
||||
return nil
|
||||
}
|
||||
var toolCalls []ToolCall
|
||||
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
|
||||
return toolCalls
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func (m *Message) SetToolCalls(toolCalls any) {
|
||||
toolCallsJson, _ := json.Marshal(toolCalls)
|
||||
m.ToolCalls = toolCallsJson
|
||||
}
|
||||
|
||||
func (m *Message) StringContent() string {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return stringContent
|
||||
@@ -120,7 +145,7 @@ func (m *Message) SetStringContent(content string) {
|
||||
m.Content = jsonContent
|
||||
}
|
||||
|
||||
func (m Message) IsStringContent() bool {
|
||||
func (m *Message) IsStringContent() bool {
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
return true
|
||||
@@ -128,11 +153,11 @@ func (m Message) IsStringContent() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m Message) ParseContent() []MediaMessage {
|
||||
var contentList []MediaMessage
|
||||
func (m *Message) ParseContent() []MediaContent {
|
||||
var contentList []MediaContent
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeText,
|
||||
Text: stringContent,
|
||||
})
|
||||
@@ -148,7 +173,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
switch contentMap["type"] {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
@@ -161,7 +186,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
} else {
|
||||
subObj["detail"] = "high"
|
||||
}
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
@@ -169,7 +194,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
},
|
||||
})
|
||||
} else if url, ok := contentMap["image_url"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: url,
|
||||
@@ -179,7 +204,7 @@ func (m Message) ParseContent() []MediaMessage {
|
||||
}
|
||||
case ContentTypeInputAudio:
|
||||
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
contentList = append(contentList, MediaContent{
|
||||
Type: ContentTypeInputAudio,
|
||||
InputAudio: MessageInputAudio{
|
||||
Data: subObj["data"].(string),
|
||||
|
||||
@@ -86,6 +86,10 @@ type ToolCall struct {
|
||||
Function FunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
func (c *ToolCall) SetIndex(i int) {
|
||||
c.Index = &i
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
|
||||
7
main.go
7
main.go
@@ -33,9 +33,11 @@ var indexPage []byte
|
||||
func main() {
|
||||
err := godotenv.Load(".env")
|
||||
if err != nil {
|
||||
common.SysError("failed to load .env file: " + err.Error())
|
||||
common.SysLog("Support for .env file is disabled")
|
||||
}
|
||||
|
||||
common.LoadEnv()
|
||||
|
||||
common.SetupLogger()
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
@@ -80,9 +82,6 @@ func main() {
|
||||
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
|
||||
model.InitChannelCache()
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
go model.SyncTokenCache(common.SyncFrequency)
|
||||
}
|
||||
if common.MemoryCacheEnabled {
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
go model.SyncChannelCache(common.SyncFrequency)
|
||||
|
||||
@@ -201,7 +201,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
userEnabled, err := model.IsUserEnabled(token.UserId, false)
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
|
||||
@@ -10,8 +10,10 @@ import (
|
||||
"one-api/model"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -38,16 +40,16 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
|
||||
return
|
||||
}
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
userGroup, _ := model.GetUserGroup(userId, false)
|
||||
tokenGroup := c.GetString("token_group")
|
||||
if tokenGroup != "" {
|
||||
// check common.UserUsableGroups[userGroup]
|
||||
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if _, ok := common.GroupRatio[tokenGroup]; !ok {
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
@@ -112,6 +114,7 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -3,10 +3,11 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
@@ -173,18 +174,67 @@ func (channel *Channel) DeleteAbilities() error {
|
||||
|
||||
// UpdateAbilities updates abilities of this channel.
|
||||
// Make sure the channel is completed before calling this function.
|
||||
func (channel *Channel) UpdateAbilities() error {
|
||||
// A quick and dirty way to update abilities
|
||||
func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
||||
isNewTx := false
|
||||
// 如果没有传入事务,创建新的事务
|
||||
if tx == nil {
|
||||
tx = DB.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
isNewTx = true
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// First delete all abilities of this channel
|
||||
err := channel.DeleteAbilities()
|
||||
err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
if isNewTx {
|
||||
tx.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Then add new abilities
|
||||
err = channel.AddAbilities()
|
||||
if err != nil {
|
||||
return err
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
ChannelId: channel.Id,
|
||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||
Priority: channel.Priority,
|
||||
Weight: uint(channel.GetWeight()),
|
||||
Tag: channel.Tag,
|
||||
}
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
}
|
||||
|
||||
if len(abilities) > 0 {
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err = tx.Create(&chunk).Error
|
||||
if err != nil {
|
||||
if isNewTx {
|
||||
tx.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是新创建的事务,需要提交
|
||||
if isNewTx {
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -246,7 +296,7 @@ func FixAbility() (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
for _, channel := range channels {
|
||||
err := channel.UpdateAbilities()
|
||||
err := channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
|
||||
} else {
|
||||
|
||||
308
model/cache.go
308
model/cache.go
@@ -1,209 +1,115 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
|
||||
// 仅用于定时同步缓存
|
||||
var token2UserId = make(map[string]int)
|
||||
var token2UserIdLock sync.RWMutex
|
||||
|
||||
func cacheSetToken(token *Token) error {
|
||||
jsonBytes, err := json.Marshal(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
|
||||
return err
|
||||
}
|
||||
token2UserIdLock.Lock()
|
||||
defer token2UserIdLock.Unlock()
|
||||
token2UserId[token.Key] = token.UserId
|
||||
return nil
|
||||
}
|
||||
|
||||
// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetTokenByKey(key)
|
||||
}
|
||||
var token *Token
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
// 如果缓存中不存在,则从数据库中获取
|
||||
token, err = GetTokenByKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = cacheSetToken(token)
|
||||
return token, nil
|
||||
}
|
||||
// 如果缓存中存在,则续期时间
|
||||
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
|
||||
err = json.Unmarshal([]byte(tokenObjectString), &token)
|
||||
return token, err
|
||||
}
|
||||
|
||||
func SyncTokenCache(frequency int) {
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Second)
|
||||
common.SysLog("syncing tokens from database")
|
||||
token2UserIdLock.Lock()
|
||||
// 从token2UserId中获取所有的key
|
||||
var copyToken2UserId = make(map[string]int)
|
||||
for s, i := range token2UserId {
|
||||
copyToken2UserId[s] = i
|
||||
}
|
||||
token2UserId = make(map[string]int)
|
||||
token2UserIdLock.Unlock()
|
||||
|
||||
for key := range copyToken2UserId {
|
||||
token, err := GetTokenByKey(key)
|
||||
if err != nil {
|
||||
// 如果数据库中不存在,则删除缓存
|
||||
common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
|
||||
//delete redis
|
||||
err := common.RedisDel(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
|
||||
}
|
||||
} else {
|
||||
// 如果数据库中存在,先检查redis
|
||||
_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
// 如果redis中不存在,则跳过
|
||||
continue
|
||||
}
|
||||
err = cacheSetToken(token)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetUserGroup(id int) (group string, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUserGroup(id)
|
||||
}
|
||||
group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
|
||||
if err != nil {
|
||||
group, err = GetUserGroup(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return group, err
|
||||
}
|
||||
|
||||
func CacheGetUsername(id int) (username string, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUsernameById(id)
|
||||
}
|
||||
username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
|
||||
if err != nil {
|
||||
username, err = GetUsernameById(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user group error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return username, err
|
||||
}
|
||||
|
||||
func CacheGetUserQuota(id int) (quota int, err error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetUserQuota(id)
|
||||
}
|
||||
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||
if err != nil {
|
||||
quota, err = GetUserQuota(id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user quota error: " + err.Error())
|
||||
}
|
||||
return quota, err
|
||||
}
|
||||
quota, err = strconv.Atoi(quotaString)
|
||||
return quota, err
|
||||
}
|
||||
|
||||
func CacheUpdateUserQuota(id int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
quota, err := GetUserQuota(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cacheSetUserQuota(id, quota)
|
||||
}
|
||||
|
||||
func cacheSetUserQuota(id int, quota int) error {
|
||||
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
return err
|
||||
}
|
||||
|
||||
func CacheDecreaseUserQuota(id int, quota int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
|
||||
return err
|
||||
}
|
||||
|
||||
func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
if !common.RedisEnabled {
|
||||
return IsUserEnabled(userId)
|
||||
}
|
||||
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||
if err == nil {
|
||||
return enabled == "1", nil
|
||||
}
|
||||
|
||||
userEnabled, err := IsUserEnabled(userId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
enabled = "0"
|
||||
if userEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
common.SysError("Redis set user enabled error: " + err.Error())
|
||||
}
|
||||
return userEnabled, err
|
||||
}
|
||||
//func CacheGetUserGroup(id int) (group string, err error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return GetUserGroup(id)
|
||||
// }
|
||||
// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
|
||||
// if err != nil {
|
||||
// group, err = GetUserGroup(id)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
|
||||
// if err != nil {
|
||||
// common.SysError("Redis set user group error: " + err.Error())
|
||||
// }
|
||||
// }
|
||||
// return group, err
|
||||
//}
|
||||
//
|
||||
//func CacheGetUsername(id int) (username string, err error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return GetUsernameById(id)
|
||||
// }
|
||||
// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
|
||||
// if err != nil {
|
||||
// username, err = GetUsernameById(id)
|
||||
// if err != nil {
|
||||
// return "", err
|
||||
// }
|
||||
// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
|
||||
// if err != nil {
|
||||
// common.SysError("Redis set user group error: " + err.Error())
|
||||
// }
|
||||
// }
|
||||
// return username, err
|
||||
//}
|
||||
//
|
||||
//func CacheGetUserQuota(id int) (quota int, err error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return GetUserQuota(id)
|
||||
// }
|
||||
// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||
// if err != nil {
|
||||
// quota, err = GetUserQuota(id)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// return quota, nil
|
||||
// }
|
||||
// quota, err = strconv.Atoi(quotaString)
|
||||
// return quota, nil
|
||||
//}
|
||||
//
|
||||
//func CacheUpdateUserQuota(id int) error {
|
||||
// if !common.RedisEnabled {
|
||||
// return nil
|
||||
// }
|
||||
// quota, err := GetUserQuota(id)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return cacheSetUserQuota(id, quota)
|
||||
//}
|
||||
//
|
||||
//func cacheSetUserQuota(id int, quota int) error {
|
||||
// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//func CacheDecreaseUserQuota(id int, quota int) error {
|
||||
// if !common.RedisEnabled {
|
||||
// return nil
|
||||
// }
|
||||
// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
// if !common.RedisEnabled {
|
||||
// return IsUserEnabled(userId)
|
||||
// }
|
||||
// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||
// if err == nil {
|
||||
// return enabled == "1", nil
|
||||
// }
|
||||
//
|
||||
// userEnabled, err := IsUserEnabled(userId)
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
// enabled = "0"
|
||||
// if userEnabled {
|
||||
// enabled = "1"
|
||||
// }
|
||||
// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
|
||||
// if err != nil {
|
||||
// common.SysError("Redis set user enabled error: " + err.Error())
|
||||
// }
|
||||
// return userEnabled, err
|
||||
//}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
@@ -344,12 +250,12 @@ func CacheGetChannel(id int) (*Channel, error) {
|
||||
}
|
||||
|
||||
func CacheUpdateChannelStatus(id int, status int) {
|
||||
if (!common.MemoryCacheEnabled) {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
if !common.MemoryCacheEnabled {
|
||||
return
|
||||
}
|
||||
channelSyncLock.Lock()
|
||||
defer channelSyncLock.Unlock()
|
||||
if channel, ok := channelsIDM[id]; ok {
|
||||
channel.Status = status
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,7 +257,7 @@ func (channel *Channel) Update() error {
|
||||
return err
|
||||
}
|
||||
DB.Model(channel).First(channel, "id = ?", channel.Id)
|
||||
err = channel.UpdateAbilities()
|
||||
err = channel.UpdateAbilities(nil)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -389,7 +389,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
|
||||
channels, err := GetChannelsByTag(updatedTag, false)
|
||||
if err == nil {
|
||||
for _, channel := range channels {
|
||||
err = channel.UpdateAbilities()
|
||||
err = channel.UpdateAbilities(nil)
|
||||
if err != nil {
|
||||
common.SysError("failed to update abilities: " + err.Error())
|
||||
}
|
||||
@@ -509,3 +509,42 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
|
||||
}
|
||||
channel.Setting = string(settingBytes)
|
||||
}
|
||||
|
||||
func GetChannelsByIds(ids []int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
err := DB.Where("id in (?)", ids).Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func BatchSetChannelTag(ids []int, tag *string) error {
|
||||
// 开启事务
|
||||
tx := DB.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// 更新标签
|
||||
err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// update ability status
|
||||
channels, err := GetChannelsByIds(ids)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
err = channel.UpdateAbilities(tx)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
55
model/log.go
55
model/log.go
@@ -28,6 +28,7 @@ type Log struct {
|
||||
IsStream bool `json:"is_stream" gorm:"default:false"`
|
||||
ChannelId int `json:"channel" gorm:"index"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
@@ -39,6 +40,19 @@ const (
|
||||
LogTypeSystem
|
||||
)
|
||||
|
||||
func formatUserLogs(logs []*Log) {
|
||||
for i := range logs {
|
||||
var otherMap map[string]interface{}
|
||||
otherMap = common.StrToMap(logs[i].Other)
|
||||
if otherMap != nil {
|
||||
// delete admin
|
||||
delete(otherMap, "admin_info")
|
||||
}
|
||||
logs[i].Other = common.MapToJsonStr(otherMap)
|
||||
logs[i].Id = logs[i].Id % 1024
|
||||
}
|
||||
}
|
||||
|
||||
func GetLogByKey(key string) (logs []*Log, err error) {
|
||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||
var tk Token
|
||||
@@ -49,6 +63,7 @@ func GetLogByKey(key string) (logs []*Log, err error) {
|
||||
} else {
|
||||
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
|
||||
}
|
||||
formatUserLogs(logs)
|
||||
return logs, err
|
||||
}
|
||||
|
||||
@@ -56,7 +71,7 @@ func RecordLog(userId int, logType int, content string) {
|
||||
if logType == LogTypeConsume && !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username, _ := CacheGetUsername(userId)
|
||||
username, _ := GetUsernameById(userId, false)
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
@@ -70,12 +85,14 @@ func RecordLog(userId int, logType int, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
|
||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
|
||||
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
|
||||
isStream bool, group string, other map[string]interface{}) {
|
||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username, _ := CacheGetUsername(userId)
|
||||
username, _ := GetUsernameById(userId, false)
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
@@ -92,6 +109,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
TokenId: tokenId,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
@@ -105,7 +123,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB
|
||||
@@ -130,6 +148,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -141,7 +162,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
|
||||
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
|
||||
var tx *gorm.DB
|
||||
if logType == LogTypeUnknown {
|
||||
tx = LOG_DB.Where("user_id = ?", userId)
|
||||
@@ -160,20 +181,15 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
if endTimestamp != 0 {
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
|
||||
for i := range logs {
|
||||
var otherMap map[string]interface{}
|
||||
otherMap = common.StrToMap(logs[i].Other)
|
||||
if otherMap != nil {
|
||||
// delete admin
|
||||
delete(otherMap, "admin_info")
|
||||
}
|
||||
logs[i].Other = common.MapToJsonStr(otherMap)
|
||||
}
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
formatUserLogs(logs)
|
||||
return logs, total, err
|
||||
}
|
||||
|
||||
@@ -183,7 +199,8 @@ func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
||||
}
|
||||
|
||||
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
||||
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
|
||||
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
||||
formatUserLogs(logs)
|
||||
return logs, err
|
||||
}
|
||||
|
||||
@@ -193,7 +210,7 @@ type Stat struct {
|
||||
Tpm int `json:"tpm"`
|
||||
}
|
||||
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
|
||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
|
||||
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
|
||||
|
||||
// 为rpm和tpm创建单独的查询
|
||||
@@ -221,6 +238,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
|
||||
}
|
||||
|
||||
tx = tx.Where("type = ?", LogTypeConsume)
|
||||
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
|
||||
|
||||
@@ -13,6 +13,20 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
var keyCol string
|
||||
|
||||
func init() {
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
keyCol = `"key"`
|
||||
|
||||
} else {
|
||||
groupCol = "`group`"
|
||||
keyCol = "`key`"
|
||||
}
|
||||
}
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
var LOG_DB *gorm.DB
|
||||
|
||||
@@ -2,7 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -61,16 +61,16 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["WorkerUrl"] = constant.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
|
||||
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = constant.Chats2JsonString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -87,8 +87,8 @@ func InitOptionMap() {
|
||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
@@ -98,17 +98,17 @@ func InitOptionMap() {
|
||||
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
|
||||
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
|
||||
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
|
||||
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
|
||||
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
|
||||
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
|
||||
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
|
||||
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
@@ -209,23 +209,23 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DefaultCollapseSidebar":
|
||||
common.DefaultCollapseSidebar = boolValue
|
||||
case "MjNotifyEnabled":
|
||||
constant.MjNotifyEnabled = boolValue
|
||||
setting.MjNotifyEnabled = boolValue
|
||||
case "MjAccountFilterEnabled":
|
||||
constant.MjAccountFilterEnabled = boolValue
|
||||
setting.MjAccountFilterEnabled = boolValue
|
||||
case "MjModeClearEnabled":
|
||||
constant.MjModeClearEnabled = boolValue
|
||||
setting.MjModeClearEnabled = boolValue
|
||||
case "MjForwardUrlEnabled":
|
||||
constant.MjForwardUrlEnabled = boolValue
|
||||
setting.MjForwardUrlEnabled = boolValue
|
||||
case "MjActionCheckSuccessEnabled":
|
||||
constant.MjActionCheckSuccessEnabled = boolValue
|
||||
setting.MjActionCheckSuccessEnabled = boolValue
|
||||
case "CheckSensitiveEnabled":
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
setting.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
constant.CheckSensitiveOnPromptEnabled = boolValue
|
||||
setting.CheckSensitiveOnPromptEnabled = boolValue
|
||||
//case "CheckSensitiveOnCompletionEnabled":
|
||||
// constant.CheckSensitiveOnCompletionEnabled = boolValue
|
||||
case "StopOnSensitiveEnabled":
|
||||
constant.StopOnSensitiveEnabled = boolValue
|
||||
setting.StopOnSensitiveEnabled = boolValue
|
||||
case "SMTPSSLEnabled":
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
}
|
||||
@@ -245,25 +245,25 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPToken":
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
constant.ServerAddress = value
|
||||
setting.ServerAddress = value
|
||||
case "WorkerUrl":
|
||||
constant.WorkerUrl = value
|
||||
setting.WorkerUrl = value
|
||||
case "WorkerValidKey":
|
||||
constant.WorkerValidKey = value
|
||||
setting.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
constant.PayAddress = value
|
||||
setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = constant.UpdateChatsByJsonString(value)
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
constant.CustomCallbackAddress = value
|
||||
setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
constant.EpayId = value
|
||||
setting.EpayId = value
|
||||
case "EpayKey":
|
||||
constant.EpayKey = value
|
||||
setting.EpayKey = value
|
||||
case "Price":
|
||||
constant.Price, _ = strconv.ParseFloat(value, 64)
|
||||
setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
constant.MinTopUp, _ = strconv.Atoi(value)
|
||||
setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "TopupGroupRatio":
|
||||
err = common.UpdateTopupGroupRatioByJSONString(value)
|
||||
case "GitHubClientId":
|
||||
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "ModelRatio":
|
||||
err = common.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = common.UpdateGroupRatioByJSONString(value)
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = common.UpdateUserUsableGroupsByJSONString(value)
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = common.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
@@ -331,9 +331,9 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "QuotaPerUnit":
|
||||
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
|
||||
case "SensitiveWords":
|
||||
constant.SensitiveWordsFromString(value)
|
||||
setting.SensitiveWordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
constant.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
149
model/token.go
149
model/token.go
@@ -3,10 +3,11 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -30,6 +31,10 @@ type Token struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
func (token *Token) Clean() {
|
||||
token.Key = ""
|
||||
}
|
||||
|
||||
func (token *Token) GetIpLimitsMap() map[string]any {
|
||||
// delete empty spaces
|
||||
//split with \n
|
||||
@@ -71,7 +76,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
if key == "" {
|
||||
return nil, errors.New("未提供令牌")
|
||||
}
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
token, err = GetTokenByKey(key, false)
|
||||
if err == nil {
|
||||
if token.Status == common.TokenStatusExhausted {
|
||||
keyPrefix := key[:3]
|
||||
@@ -128,22 +133,38 @@ func GetTokenById(id int) (*Token, error) {
|
||||
token := Token{Id: id}
|
||||
var err error = nil
|
||||
err = DB.First(&token, "id = ?", id).Error
|
||||
if err != nil {
|
||||
if common.RedisEnabled {
|
||||
go cacheSetToken(&token)
|
||||
}
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
if err := cacheSetToken(token); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
return &token, err
|
||||
}
|
||||
|
||||
func GetTokenByKey(key string) (*Token, error) {
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) && token != nil {
|
||||
gopool.Go(func() {
|
||||
if err := cacheSetToken(*token); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
// Try Redis first
|
||||
token, err := cacheGetTokenByKey(key)
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
var token Token
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return &token, err
|
||||
fromDB = true
|
||||
err = DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return token, err
|
||||
}
|
||||
|
||||
func (token *Token) Insert() error {
|
||||
@@ -153,20 +174,48 @@ func (token *Token) Insert() error {
|
||||
}
|
||||
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (token *Token) Update() error {
|
||||
var err error
|
||||
func (token *Token) Update() (err error) {
|
||||
defer func() {
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
err := cacheSetToken(*token)
|
||||
if err != nil {
|
||||
common.SysError("failed to update token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
|
||||
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (token *Token) SelectUpdate() error {
|
||||
func (token *Token) SelectUpdate() (err error) {
|
||||
defer func() {
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
err := cacheSetToken(*token)
|
||||
if err != nil {
|
||||
common.SysError("failed to update token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
// This can update zero values
|
||||
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
|
||||
}
|
||||
|
||||
func (token *Token) Delete() error {
|
||||
var err error
|
||||
func (token *Token) Delete() (err error) {
|
||||
defer func() {
|
||||
if shouldUpdateRedis(true, err) {
|
||||
gopool.Go(func() {
|
||||
err := cacheDeleteToken(token.Key)
|
||||
if err != nil {
|
||||
common.SysError("failed to delete token cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
err = DB.Delete(token).Error
|
||||
return err
|
||||
}
|
||||
@@ -214,10 +263,18 @@ func DeleteTokenById(id int, userId int) (err error) {
|
||||
return token.Delete()
|
||||
}
|
||||
|
||||
func IncreaseTokenQuota(id int, quota int) (err error) {
|
||||
func IncreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
gopool.Go(func() {
|
||||
err := cacheIncrTokenQuota(key, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to increase token quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
|
||||
return nil
|
||||
@@ -236,10 +293,18 @@ func increaseTokenQuota(id int, quota int) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func DecreaseTokenQuota(id int, quota int) (err error) {
|
||||
func DecreaseTokenQuota(id int, key string, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if common.RedisEnabled {
|
||||
gopool.Go(func() {
|
||||
err := cacheDecrTokenQuota(key, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to decrease token quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
|
||||
return nil
|
||||
@@ -258,37 +323,31 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||
if quota < 0 {
|
||||
return 0, errors.New("quota 不能为负数!")
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
if !relayInfo.IsPlayground {
|
||||
token, err := GetTokenById(relayInfo.TokenId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
||||
return 0, errors.New("令牌额度不足")
|
||||
}
|
||||
if relayInfo.IsPlayground {
|
||||
return nil
|
||||
}
|
||||
userQuota, err = GetUserQuota(relayInfo.UserId)
|
||||
//if relayInfo.TokenUnlimited {
|
||||
// return nil
|
||||
//}
|
||||
token, err := GetTokenById(relayInfo.TokenId)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
if userQuota < quota {
|
||||
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
|
||||
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
||||
return errors.New("令牌额度不足")
|
||||
}
|
||||
if !relayInfo.IsPlayground {
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
||||
return userQuota - quota, err
|
||||
return nil
|
||||
}
|
||||
|
||||
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
||||
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
|
||||
|
||||
if quota > 0 {
|
||||
err = DecreaseUserQuota(relayInfo.UserId, quota)
|
||||
@@ -301,9 +360,9 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
|
||||
|
||||
if !relayInfo.IsPlayground {
|
||||
if quota > 0 {
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
|
||||
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||
} else {
|
||||
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
|
||||
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -325,7 +384,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
|
||||
prompt = "您的额度已用尽"
|
||||
}
|
||||
if email != "" {
|
||||
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
|
||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||
err = common.SendEmail(prompt, email,
|
||||
fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
|
||||
if err != nil {
|
||||
|
||||
64
model/token_cache.go
Normal file
64
model/token_cache.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"time"
|
||||
)
|
||||
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheDeleteToken(key string) error {
|
||||
key = common.GenerateHMAC(key)
|
||||
err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheIncrTokenQuota(key string, increment int64) error {
|
||||
key = common.GenerateHMAC(key)
|
||||
err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cacheDecrTokenQuota(key string, decrement int64) error {
|
||||
return cacheIncrTokenQuota(key, -decrement)
|
||||
}
|
||||
|
||||
func cacheSetTokenField(key string, field string, value string) error {
|
||||
key = common.GenerateHMAC(key)
|
||||
err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CacheGetTokenByKey 从缓存中获取 token,如果缓存中不存在,则从数据库中获取
|
||||
func cacheGetTokenByKey(key string) (*Token, error) {
|
||||
hmacKey := common.GenerateHMAC(key)
|
||||
if !common.RedisEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
var token Token
|
||||
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.Key = key
|
||||
return &token, nil
|
||||
}
|
||||
192
model/user.go
192
model/user.go
@@ -6,7 +6,8 @@ import (
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -93,9 +94,9 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
keywordInt, err := strconv.Atoi(keyword)
|
||||
if err == nil {
|
||||
// 如果转换成功,按照ID和可选的组别搜索用户
|
||||
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
|
||||
query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
|
||||
if group != "" {
|
||||
query = query.Where("`group` = ?", group) // 使用反引号包围group
|
||||
query = query.Where(groupCol+" = ?", group) // 使用反引号包围group
|
||||
}
|
||||
err = query.Find(&users).Error
|
||||
if err != nil || len(users) > 0 {
|
||||
@@ -106,9 +107,9 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||
err = nil
|
||||
|
||||
query := DB.Unscoped().Omit("password")
|
||||
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
|
||||
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
@@ -246,14 +247,12 @@ func (user *User) Update(updatePassword bool) error {
|
||||
}
|
||||
newUser := *user
|
||||
DB.First(&user, user.Id)
|
||||
err = DB.Model(user).Updates(newUser).Error
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
}
|
||||
if err = DB.Model(user).Updates(newUser).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
|
||||
// 更新缓存
|
||||
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
||||
}
|
||||
|
||||
func (user *User) Edit(updatePassword bool) error {
|
||||
@@ -264,6 +263,7 @@ func (user *User) Edit(updatePassword bool) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
newUser := *user
|
||||
updates := map[string]interface{}{
|
||||
"username": newUser.Username,
|
||||
@@ -274,23 +274,26 @@ func (user *User) Edit(updatePassword bool) error {
|
||||
if updatePassword {
|
||||
updates["password"] = newUser.Password
|
||||
}
|
||||
|
||||
DB.First(&user, user.Id)
|
||||
err = DB.Model(user).Updates(updates).Error
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||
}
|
||||
if err = DB.Model(user).Updates(updates).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
|
||||
// 更新缓存
|
||||
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
|
||||
}
|
||||
|
||||
func (user *User) Delete() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
err := DB.Delete(user).Error
|
||||
return err
|
||||
if err := DB.Delete(user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除缓存
|
||||
return invalidateUserCache(user.Id)
|
||||
}
|
||||
|
||||
func (user *User) HardDelete() error {
|
||||
@@ -404,15 +407,33 @@ func IsAdmin(userId int) bool {
|
||||
return user.Role >= common.RoleAdminUser
|
||||
}
|
||||
|
||||
func IsUserEnabled(userId int) (bool, error) {
|
||||
if userId == 0 {
|
||||
return false, errors.New("user id is empty")
|
||||
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
|
||||
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserStatusCache(id, status); err != nil {
|
||||
common.SysError("failed to update user status cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
// Try Redis first
|
||||
status, err := getUserStatusCache(id)
|
||||
if err == nil {
|
||||
return status == common.UserStatusEnabled, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
var user User
|
||||
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
|
||||
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return user.Status == common.UserStatusEnabled, nil
|
||||
}
|
||||
|
||||
@@ -428,14 +449,33 @@ func ValidateAccessToken(token string) (user *User) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUserQuota(id int) (quota int, err error) {
|
||||
// GetUserQuota gets quota from Redis first, falls back to DB if needed
|
||||
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserQuotaCache(id, quota); err != nil {
|
||||
common.SysError("failed to update user quota cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
quota, err := getUserQuotaCache(id)
|
||||
if err == nil {
|
||||
return quota, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
//common.SysError("failed to get user quota from cache: " + err.Error())
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||
if err != nil {
|
||||
if common.RedisEnabled {
|
||||
go cacheSetUserQuota(id, quota)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return quota, err
|
||||
|
||||
return quota, nil
|
||||
}
|
||||
|
||||
func GetUserUsedQuota(id int) (quota int, err error) {
|
||||
@@ -448,20 +488,44 @@ func GetUserEmail(id int) (email string, err error) {
|
||||
return email, err
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
// GetUserGroup gets group from Redis first, falls back to DB if needed
|
||||
func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserGroupCache(id, group); err != nil {
|
||||
common.SysError("failed to update user group cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
group, err := getUserGroupCache(id)
|
||||
if err == nil {
|
||||
return group, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
return group, err
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func IncreaseUserQuota(id int, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
gopool.Go(func() {
|
||||
err := cacheIncrUserQuota(id, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to increase user quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
|
||||
return nil
|
||||
@@ -471,6 +535,9 @@ func IncreaseUserQuota(id int, quota int) (err error) {
|
||||
|
||||
func increaseUserQuota(id int, quota int) (err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -478,6 +545,12 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
||||
if quota < 0 {
|
||||
return errors.New("quota 不能为负数!")
|
||||
}
|
||||
gopool.Go(func() {
|
||||
err := cacheDecrUserQuota(id, int64(quota))
|
||||
if err != nil {
|
||||
common.SysError("failed to decrease user quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
|
||||
return nil
|
||||
@@ -487,9 +560,23 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
||||
|
||||
func decreaseUserQuota(id int, quota int) (err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func DeltaUpdateUserQuota(id int, delta int) (err error) {
|
||||
if delta == 0 {
|
||||
return nil
|
||||
}
|
||||
if delta > 0 {
|
||||
return IncreaseUserQuota(id, delta)
|
||||
} else {
|
||||
return DecreaseUserQuota(id, -delta)
|
||||
}
|
||||
}
|
||||
|
||||
func GetRootUserEmail() (email string) {
|
||||
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
||||
return email
|
||||
@@ -513,7 +600,13 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user used quota and request count: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
//// 更新缓存
|
||||
//if err := invalidateUserCache(id); err != nil {
|
||||
// common.SysError("failed to invalidate user cache: " + err.Error())
|
||||
//}
|
||||
}
|
||||
|
||||
func updateUserUsedQuota(id int, quota int) {
|
||||
@@ -534,9 +627,32 @@ func updateUserRequestCount(id int, count int) {
|
||||
}
|
||||
}
|
||||
|
||||
func GetUsernameById(id int) (username string, err error) {
|
||||
// GetUsernameById gets username from Redis first, falls back to DB if needed
|
||||
func GetUsernameById(id int, fromDB bool) (username string, err error) {
|
||||
defer func() {
|
||||
// Update Redis cache asynchronously on successful DB read
|
||||
if shouldUpdateRedis(fromDB, err) {
|
||||
gopool.Go(func() {
|
||||
if err := updateUserNameCache(id, username); err != nil {
|
||||
common.SysError("failed to update user name cache: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}()
|
||||
if !fromDB && common.RedisEnabled {
|
||||
username, err := getUserNameCache(id)
|
||||
if err == nil {
|
||||
return username, nil
|
||||
}
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
|
||||
return username, err
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {
|
||||
|
||||
206
model/user_cache.go
Normal file
206
model/user_cache.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Change UserCache struct to userCache
|
||||
type userCache struct {
|
||||
Id int `json:"id"`
|
||||
Group string `json:"group"`
|
||||
Quota int `json:"quota"`
|
||||
Status int `json:"status"`
|
||||
Role int `json:"role"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
// Rename all exported functions to private ones
|
||||
// invalidateUserCache clears all user related cache
|
||||
func invalidateUserCache(userId int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := []string{
|
||||
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
||||
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
||||
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
||||
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if err := common.RedisDel(key); err != nil {
|
||||
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateUserGroupCache updates user group cache
|
||||
func updateUserGroupCache(userId int, group string) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
|
||||
group,
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserQuotaCache updates user quota cache
|
||||
func updateUserQuotaCache(userId int, quota int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
|
||||
fmt.Sprintf("%d", quota),
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserStatusCache updates user status cache
|
||||
func updateUserStatusCache(userId int, userEnabled bool) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
enabled := "0"
|
||||
if userEnabled {
|
||||
enabled = "1"
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
|
||||
enabled,
|
||||
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserNameCache updates username cache
|
||||
func updateUserNameCache(userId int, username string) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
return common.RedisSet(
|
||||
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
|
||||
username,
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
// updateUserCache updates all user cache fields
|
||||
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := updateUserGroupCache(userId, userGroup); err != nil {
|
||||
return fmt.Errorf("update group cache: %w", err)
|
||||
}
|
||||
|
||||
if err := updateUserQuotaCache(userId, quota); err != nil {
|
||||
return fmt.Errorf("update quota cache: %w", err)
|
||||
}
|
||||
|
||||
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
|
||||
return fmt.Errorf("update status cache: %w", err)
|
||||
}
|
||||
|
||||
if err := updateUserNameCache(userId, username); err != nil {
|
||||
return fmt.Errorf("update username cache: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getUserGroupCache gets user group from cache
|
||||
func getUserGroupCache(userId int) (string, error) {
|
||||
if !common.RedisEnabled {
|
||||
return "", nil
|
||||
}
|
||||
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
|
||||
}
|
||||
|
||||
// getUserQuotaCache gets user quota from cache
|
||||
func getUserQuotaCache(userId int) (int, error) {
|
||||
if !common.RedisEnabled {
|
||||
return 0, nil
|
||||
}
|
||||
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.Atoi(quotaStr)
|
||||
}
|
||||
|
||||
// getUserStatusCache gets user status from cache
|
||||
func getUserStatusCache(userId int) (int, error) {
|
||||
if !common.RedisEnabled {
|
||||
return 0, nil
|
||||
}
|
||||
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.Atoi(statusStr)
|
||||
}
|
||||
|
||||
// getUserNameCache gets username from cache
|
||||
func getUserNameCache(userId int) (string, error) {
|
||||
if !common.RedisEnabled {
|
||||
return "", nil
|
||||
}
|
||||
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
|
||||
}
|
||||
|
||||
// getUserCache gets complete user cache
|
||||
func getUserCache(userId int) (*userCache, error) {
|
||||
if !common.RedisEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
group, err := getUserGroupCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group cache: %w", err)
|
||||
}
|
||||
|
||||
quota, err := getUserQuotaCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get quota cache: %w", err)
|
||||
}
|
||||
|
||||
status, err := getUserStatusCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get status cache: %w", err)
|
||||
}
|
||||
|
||||
username, err := getUserNameCache(userId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get username cache: %w", err)
|
||||
}
|
||||
|
||||
return &userCache{
|
||||
Id: userId,
|
||||
Group: group,
|
||||
Quota: quota,
|
||||
Status: status,
|
||||
Username: username,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add atomic quota operations
|
||||
func cacheIncrUserQuota(userId int, delta int64) error {
|
||||
if !common.RedisEnabled {
|
||||
return nil
|
||||
}
|
||||
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
|
||||
return common.RedisIncr(key, delta)
|
||||
}
|
||||
|
||||
func cacheDecrUserQuota(userId int, delta int64) error {
|
||||
return cacheIncrUserQuota(userId, -delta)
|
||||
}
|
||||
@@ -88,3 +88,7 @@ func RecordExist(err error) (bool, error) {
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func shouldUpdateRedis(fromDB bool, err error) bool {
|
||||
return common.RedisEnabled && fromDB && err == nil
|
||||
}
|
||||
|
||||
@@ -225,9 +225,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
|
||||
claudeMediaMessage.Source.MediaType = mimeType
|
||||
claudeMediaMessage.Source.Data = data
|
||||
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
claudeMediaMessage.Source.MediaType = fileData.MimeType
|
||||
claudeMediaMessage.Source.Data = fileData.Base64Data
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
|
||||
if err != nil {
|
||||
@@ -240,14 +243,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
||||
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
|
||||
}
|
||||
if message.ToolCalls != nil {
|
||||
for _, tc := range message.ToolCalls.([]interface{}) {
|
||||
toolCallJSON, _ := json.Marshal(tc)
|
||||
var toolCall dto.ToolCall
|
||||
err := json.Unmarshal(toolCallJSON, &toolCall)
|
||||
if err != nil {
|
||||
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
|
||||
continue
|
||||
}
|
||||
for _, toolCall := range message.ParseToolCalls() {
|
||||
inputObj := make(map[string]any)
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
|
||||
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
|
||||
@@ -393,7 +389,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
|
||||
}
|
||||
choice.SetStringContent(responseText)
|
||||
if len(tools) > 0 {
|
||||
choice.Message.ToolCalls = tools
|
||||
choice.Message.SetToolCalls(tools)
|
||||
}
|
||||
fullTextResponse.Model = claudeResponse.Model
|
||||
choices = append(choices, choice)
|
||||
|
||||
@@ -57,7 +57,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return CovertGemini2OpenAI(*request), nil
|
||||
ai, err := CovertGemini2OpenAI(*request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ai, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
@@ -72,7 +76,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
err, usage = GeminiChatStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = GeminiChatHandler(c, resp)
|
||||
err, usage = GeminiChatHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
package gemini
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
// stable version
|
||||
"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
|
||||
|
||||
@@ -4,7 +4,7 @@ type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
||||
Tools []GeminiChatTool `json:"tools,omitempty"`
|
||||
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
|
||||
}
|
||||
|
||||
@@ -18,10 +18,39 @@ type FunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
type GeminiFunctionResponseContent struct {
|
||||
Name string `json:"name"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response GeminiFunctionResponseContent `json:"response"`
|
||||
}
|
||||
|
||||
type GeminiPartExecutableCode struct {
|
||||
Language string `json:"language,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiPartCodeExecutionResult struct {
|
||||
Outcome string `json:"outcome,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiFileData struct {
|
||||
MimeType string `json:"mimeType,omitempty"`
|
||||
FileUri string `json:"fileUri,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
@@ -34,23 +63,28 @@ type GeminiChatSafetySettings struct {
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTools struct {
|
||||
GoogleSearch any `json:"googleSearch,omitempty"`
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
type GeminiChatTool struct {
|
||||
GoogleSearch any `json:"googleSearch,omitempty"`
|
||||
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
|
||||
CodeExecution any `json:"codeExecution,omitempty"`
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
FinishReason *string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
@@ -12,12 +12,14 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []GeminiChatSafetySettings{
|
||||
@@ -46,147 +48,320 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
Seed: int64(textRequest.Seed),
|
||||
},
|
||||
}
|
||||
|
||||
// openaiContent.FuncToToolCalls()
|
||||
if textRequest.Tools != nil {
|
||||
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
|
||||
googleSearch := false
|
||||
codeExecution := false
|
||||
for _, tool := range textRequest.Tools {
|
||||
if tool.Function.Name == "googleSearch" {
|
||||
googleSearch = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Name == "codeExecution" {
|
||||
codeExecution = true
|
||||
continue
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
params, ok := tool.Function.Parameters.(map[string]interface{})
|
||||
if ok {
|
||||
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
|
||||
if len(props) == 0 {
|
||||
tool.Function.Parameters = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
functions = append(functions, tool.Function)
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = []GeminiChatTools{
|
||||
{
|
||||
FunctionDeclarations: functions,
|
||||
},
|
||||
}
|
||||
if codeExecution {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
CodeExecution: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if googleSearch {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTools{
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
GoogleSearch: make(map[string]string),
|
||||
})
|
||||
}
|
||||
if len(functions) > 0 {
|
||||
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
|
||||
FunctionDeclarations: functions,
|
||||
})
|
||||
}
|
||||
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
|
||||
// json_data, _ := json.Marshal(geminiRequest.Tools)
|
||||
// common.SysLog("tools_json: " + string(json_data))
|
||||
} else if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []GeminiChatTools{
|
||||
geminiRequest.Tools = []GeminiChatTool{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
|
||||
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
|
||||
|
||||
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
|
||||
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
|
||||
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
|
||||
}
|
||||
}
|
||||
tool_call_ids := make(map[string]string)
|
||||
var system_content []string
|
||||
//shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
|
||||
if message.Role == "system" {
|
||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
system_content = append(system_content, message.StringContent())
|
||||
continue
|
||||
} else if message.Role == "tool" || message.Role == "function" {
|
||||
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
Role: "user",
|
||||
})
|
||||
}
|
||||
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
|
||||
name := ""
|
||||
if message.Name != nil {
|
||||
name = *message.Name
|
||||
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
|
||||
name = val
|
||||
}
|
||||
content := common.StrToMap(message.StringContent())
|
||||
functionResp := &FunctionResponse{
|
||||
Name: name,
|
||||
Response: GeminiFunctionResponseContent{
|
||||
Name: name,
|
||||
Content: content,
|
||||
},
|
||||
}
|
||||
if content == nil {
|
||||
functionResp.Response.Content = message.StringContent()
|
||||
}
|
||||
*parts = append(*parts, GeminiPart{
|
||||
FunctionResponse: functionResp,
|
||||
})
|
||||
continue
|
||||
}
|
||||
var parts []GeminiPart
|
||||
content := GeminiChatContent{
|
||||
Role: message.Role,
|
||||
//Parts: []GeminiPart{
|
||||
// {
|
||||
// Text: message.StringContent(),
|
||||
// },
|
||||
//},
|
||||
}
|
||||
// isToolCall := false
|
||||
if message.ToolCalls != nil {
|
||||
// message.Role = "model"
|
||||
// isToolCall = true
|
||||
for _, call := range message.ParseToolCalls() {
|
||||
args := map[string]interface{}{}
|
||||
if call.Function.Arguments != "" {
|
||||
if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
|
||||
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
|
||||
}
|
||||
}
|
||||
toolCall := GeminiPart{
|
||||
FunctionCall: &FunctionCall{
|
||||
FunctionName: call.Function.Name,
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
parts = append(parts, toolCall)
|
||||
tool_call_ids[call.ID] = call.Function.Name
|
||||
}
|
||||
}
|
||||
|
||||
openaiContent := message.ParseContent()
|
||||
var parts []GeminiPart
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
if part.Type == dto.ContentTypeText {
|
||||
if part.Text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == dto.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
//if imageNum > GeminiVisionMaxImageNum {
|
||||
// continue
|
||||
//}
|
||||
|
||||
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
|
||||
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
|
||||
}
|
||||
// 判断是否是url
|
||||
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||
// 是url,获取图片的类型和base64编码的数据
|
||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
continue
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: "image/" + format,
|
||||
MimeType: fileData.MimeType,
|
||||
Data: fileData.Base64Data,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
|
||||
}
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: format,
|
||||
Data: base64String,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
// Converting system prompt to prompt from user for the same reason
|
||||
//if content.Role == "system" {
|
||||
// content.Role = "user"
|
||||
// shouldAddDummyModelMessage = true
|
||||
//}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
//
|
||||
//// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
//if shouldAddDummyModelMessage {
|
||||
// geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
// Role: "model",
|
||||
// Parts: []GeminiPart{
|
||||
// {
|
||||
// Text: "Okay",
|
||||
// },
|
||||
// },
|
||||
// })
|
||||
// shouldAddDummyModelMessage = false
|
||||
//}
|
||||
}
|
||||
return &geminiRequest
|
||||
|
||||
if len(system_content) > 0 {
|
||||
geminiRequest.SystemInstructions = &GeminiChatContent{
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: strings.Join(system_content, "\n"),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
func (g *GeminiChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
|
||||
if depth >= 5 {
|
||||
return schema
|
||||
}
|
||||
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
return g.Candidates[0].Content.Parts[0].Text
|
||||
|
||||
v, ok := schema.(map[string]interface{})
|
||||
if !ok || len(v) == 0 {
|
||||
return schema
|
||||
}
|
||||
return ""
|
||||
// 删除所有的title字段
|
||||
delete(v, "title")
|
||||
// 如果type不为object和array,则直接返回
|
||||
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
|
||||
return schema
|
||||
}
|
||||
switch v["type"] {
|
||||
case "object":
|
||||
delete(v, "additionalProperties")
|
||||
// 处理 properties
|
||||
if properties, ok := v["properties"].(map[string]interface{}); ok {
|
||||
for key, value := range properties {
|
||||
properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
|
||||
}
|
||||
}
|
||||
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
|
||||
if nested, ok := v[field].([]interface{}); ok {
|
||||
for i, item := range nested {
|
||||
nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "array":
|
||||
if items, ok := v["items"].(map[string]interface{}); ok {
|
||||
v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
|
||||
}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
||||
var toolCalls []dto.ToolCall
|
||||
func unescapeString(s string) (string, error) {
|
||||
var result []rune
|
||||
escaped := false
|
||||
i := 0
|
||||
|
||||
item := candidate.Content.Parts[0]
|
||||
if item.FunctionCall == nil {
|
||||
return toolCalls
|
||||
for i < len(s) {
|
||||
r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
|
||||
if r == utf8.RuneError {
|
||||
return "", fmt.Errorf("invalid UTF-8 encoding")
|
||||
}
|
||||
|
||||
if escaped {
|
||||
// 如果是转义符后的字符,检查其类型
|
||||
switch r {
|
||||
case '"':
|
||||
result = append(result, '"')
|
||||
case '\\':
|
||||
result = append(result, '\\')
|
||||
case '/':
|
||||
result = append(result, '/')
|
||||
case 'b':
|
||||
result = append(result, '\b')
|
||||
case 'f':
|
||||
result = append(result, '\f')
|
||||
case 'n':
|
||||
result = append(result, '\n')
|
||||
case 'r':
|
||||
result = append(result, '\r')
|
||||
case 't':
|
||||
result = append(result, '\t')
|
||||
case '\'':
|
||||
result = append(result, '\'')
|
||||
default:
|
||||
// 如果遇到一个非法的转义字符,直接按原样输出
|
||||
result = append(result, '\\', r)
|
||||
}
|
||||
escaped = false
|
||||
} else {
|
||||
if r == '\\' {
|
||||
escaped = true // 记录反斜杠作为转义符
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
i += size // 移动到下一个字符
|
||||
}
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
|
||||
return string(result), nil
|
||||
}
|
||||
func unescapeMapOrSlice(data interface{}) interface{} {
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
for k, val := range v {
|
||||
v[k] = unescapeMapOrSlice(val)
|
||||
}
|
||||
case []interface{}:
|
||||
for i, val := range v {
|
||||
v[i] = unescapeMapOrSlice(val)
|
||||
}
|
||||
case string:
|
||||
if unescaped, err := unescapeString(v); err != nil {
|
||||
return v
|
||||
} else {
|
||||
return unescaped
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func getToolCall(item *GeminiPart) *dto.ToolCall {
|
||||
var argsBytes []byte
|
||||
var err error
|
||||
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
|
||||
argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
|
||||
} else {
|
||||
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
//common.SysError("getToolCalls failed: " + err.Error())
|
||||
return toolCalls
|
||||
return nil
|
||||
}
|
||||
toolCall := dto.ToolCall{
|
||||
return &dto.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", common.GetUUID()),
|
||||
Type: "function",
|
||||
Function: dto.FunctionCall{
|
||||
@@ -194,8 +369,6 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
|
||||
Name: item.FunctionCall.FunctionName,
|
||||
},
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
@@ -206,9 +379,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
for i, candidate := range response.Candidates {
|
||||
is_tool_call := false
|
||||
for _, candidate := range response.Candidates {
|
||||
choice := dto.OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Index: int(candidate.Index),
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
@@ -216,48 +390,116 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
|
||||
FinishReason: constant.FinishReasonStop,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
choice.Message.ToolCalls = getToolCalls(&candidate)
|
||||
} else {
|
||||
var texts []string
|
||||
for _, part := range candidate.Content.Parts {
|
||||
texts = append(texts, part.Text)
|
||||
var texts []string
|
||||
var tool_calls []dto.ToolCall
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
if call := getToolCall(&part); call != nil {
|
||||
tool_calls = append(tool_calls, *call)
|
||||
}
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
|
||||
} else if part.CodeExecutionResult != nil {
|
||||
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
|
||||
} else {
|
||||
// 过滤掉空行
|
||||
if part.Text != "\n" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
}
|
||||
if len(tool_calls) > 0 {
|
||||
choice.Message.SetToolCalls(tool_calls)
|
||||
is_tool_call = true
|
||||
}
|
||||
|
||||
choice.Message.SetStringContent(strings.Join(texts, "\n"))
|
||||
|
||||
}
|
||||
if candidate.FinishReason != nil {
|
||||
switch *candidate.FinishReason {
|
||||
case "STOP":
|
||||
choice.FinishReason = constant.FinishReasonStop
|
||||
case "MAX_TOKENS":
|
||||
choice.FinishReason = constant.FinishReasonLength
|
||||
default:
|
||||
choice.FinishReason = constant.FinishReasonContentFilter
|
||||
}
|
||||
}
|
||||
if is_tool_call {
|
||||
choice.FinishReason = constant.FinishReasonToolCalls
|
||||
}
|
||||
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
|
||||
var choice dto.ChatCompletionsStreamResponseChoice
|
||||
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
|
||||
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
|
||||
respFirstParts := geminiResponse.Candidates[0].Content.Parts
|
||||
if respFirstParts[0].FunctionCall != nil {
|
||||
// function response
|
||||
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
|
||||
} else {
|
||||
// text response
|
||||
var texts []string
|
||||
for _, part := range respFirstParts {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
|
||||
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
|
||||
is_stop := false
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
|
||||
is_stop = true
|
||||
candidate.FinishReason = nil
|
||||
}
|
||||
choice := dto.ChatCompletionsStreamResponseChoice{
|
||||
Index: int(candidate.Index),
|
||||
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||
Role: "assistant",
|
||||
},
|
||||
}
|
||||
var texts []string
|
||||
isTools := false
|
||||
if candidate.FinishReason != nil {
|
||||
// p := GeminiConvertFinishReason(*candidate.FinishReason)
|
||||
switch *candidate.FinishReason {
|
||||
case "STOP":
|
||||
choice.FinishReason = &constant.FinishReasonStop
|
||||
case "MAX_TOKENS":
|
||||
choice.FinishReason = &constant.FinishReasonLength
|
||||
default:
|
||||
choice.FinishReason = &constant.FinishReasonContentFilter
|
||||
}
|
||||
}
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
isTools = true
|
||||
if call := getToolCall(&part); call != nil {
|
||||
call.SetIndex(len(choice.Delta.ToolCalls))
|
||||
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
|
||||
}
|
||||
} else {
|
||||
if part.ExecutableCode != nil {
|
||||
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
|
||||
} else if part.CodeExecutionResult != nil {
|
||||
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
|
||||
} else {
|
||||
if part.Text != "\n" {
|
||||
texts = append(texts, part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
choice.Delta.SetContentString(strings.Join(texts, "\n"))
|
||||
if isTools {
|
||||
choice.FinishReason = &constant.FinishReasonToolCalls
|
||||
}
|
||||
choices = append(choices, choice)
|
||||
}
|
||||
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
response.Choices = choices
|
||||
return &response, is_stop
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseText := ""
|
||||
// responseText := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
@@ -281,13 +523,11 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
continue
|
||||
}
|
||||
|
||||
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
if response == nil {
|
||||
continue
|
||||
}
|
||||
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
responseText += response.Choices[0].Delta.GetContentString()
|
||||
response.Model = info.UpstreamModelName
|
||||
// responseText += response.Choices[0].Delta.GetContentString()
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
@@ -296,12 +536,17 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
if is_stop {
|
||||
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||
service.ObjectData(c, response)
|
||||
}
|
||||
}
|
||||
|
||||
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
|
||||
service.ObjectData(c, response)
|
||||
var response *dto.ChatCompletionsStreamResponse
|
||||
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
@@ -315,7 +560,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
@@ -341,6 +586,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
|
||||
@@ -106,15 +106,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
if info.ChannelType != common.ChannelTypeOpenAI {
|
||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o1-") {
|
||||
if strings.HasPrefix(request.Model, "o1") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
request.MaxTokens = 0
|
||||
}
|
||||
}
|
||||
if request.Model == "o1" || request.Model == "o1-2024-12-17" {
|
||||
//修改第一个Message的内容,将system改为developer
|
||||
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
|
||||
request.Messages[0].Role = "developer"
|
||||
}
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ var ModelList = []string{
|
||||
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
|
||||
"o1-preview", "o1-preview-2024-09-12",
|
||||
"o1-mini", "o1-mini-2024-09-12",
|
||||
"o1", "o1-2024-12-17",
|
||||
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
|
||||
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01",
|
||||
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
|
||||
|
||||
@@ -135,7 +135,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
c.Set("request_model", request.Model)
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest := gemini.CovertGemini2OpenAI(*request)
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Set("request_model", request.Model)
|
||||
return geminiRequest, nil
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
@@ -167,7 +170,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case RequestModeClaude:
|
||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||
case RequestModeGemini:
|
||||
err, usage = gemini.GeminiChatHandler(c, resp)
|
||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||
case RequestModeLlama:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
|
||||
}
|
||||
|
||||
@@ -2,8 +2,9 @@ package common
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/constant"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -66,13 +67,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := time.Now()
|
||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &RelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
@@ -108,7 +109,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
|
||||
info.ChannelType == common.ChannelCloudflare {
|
||||
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
return info
|
||||
@@ -158,10 +159,10 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
group := c.GetString("group")
|
||||
startTime := time.Now()
|
||||
|
||||
apiType, _ := constant.ChannelType2APIType(channelType)
|
||||
apiType, _ := relayconstant.ChannelType2APIType(channelType)
|
||||
|
||||
info := &TaskRelayInfo{
|
||||
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
|
||||
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
|
||||
BaseUrl: c.GetString("base_url"),
|
||||
RequestURLPath: c.Request.URL.String(),
|
||||
ChannelType: channelType,
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
|
||||
@@ -26,7 +26,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
if audioRequest.Model == "" {
|
||||
return nil, errors.New("model is required")
|
||||
}
|
||||
if constant.ShouldCheckPromptSensitive() {
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err := service.CheckSensitiveInput(audioRequest.Input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -74,30 +74,16 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(audioRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
defer func() {
|
||||
if openaiErr != nil {
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
||||
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||
//}
|
||||
if constant.ShouldCheckPromptSensitive() {
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err := service.CheckSensitiveInput(imageRequest.Prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -99,8 +99,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||
modelPrice = 0.0025 * modelRatio
|
||||
}
|
||||
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
|
||||
sizeRatio := 1.0
|
||||
// Size
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -111,8 +112,8 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.StartTime = originTask.StartTime
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
@@ -167,9 +168,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -193,11 +194,11 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
}
|
||||
defer func(ctx context.Context) {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
//err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
@@ -207,7 +208,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
@@ -421,7 +423,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
if originTask == nil {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||||
if constant.MjActionCheckSuccessEnabled {
|
||||
if setting.MjActionCheckSuccessEnabled {
|
||||
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||||
}
|
||||
@@ -472,9 +474,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -498,21 +500,18 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -93,24 +94,31 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
if constant.ShouldCheckPromptSensitive() {
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
err = checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
promptTokens, err := getPromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
var promptTokens int
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
promptTokens = value.(int)
|
||||
} else {
|
||||
promptTokens, err = getPromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
}
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
if !getModelPriceSuccess {
|
||||
@@ -222,7 +230,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
var err error
|
||||
switch info.RelayMode {
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
|
||||
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case relayconstant.RelayModeModerations:
|
||||
@@ -254,7 +262,7 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
|
||||
|
||||
// 预扣费并返回用户剩余配额
|
||||
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -264,10 +272,6 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !relayInfo.TokenUnlimited {
|
||||
@@ -285,11 +289,16 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
||||
}
|
||||
}
|
||||
|
||||
if preConsumedQuota > 0 {
|
||||
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
return preConsumedQuota, userQuota, nil
|
||||
}
|
||||
@@ -299,7 +308,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
|
||||
go func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
|
||||
err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysError("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
@@ -357,15 +366,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
//}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
@@ -384,7 +389,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
|
||||
//if quota != 0 {
|
||||
//
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
|
||||
relayInfo.UpstreamModelName = rerankRequest.Model
|
||||
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -48,9 +49,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -112,21 +113,18 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
// release quota
|
||||
if relayInfo.ConsumeQuota && taskErr == nil {
|
||||
|
||||
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
|
||||
err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
|
||||
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
|
||||
@@ -28,10 +28,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
|
||||
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
@@ -98,7 +98,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.POST("/batch", controller.DeleteChannelBatch)
|
||||
channelRoute.POST("/fix", controller.FixChannelsAbilities)
|
||||
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
|
||||
|
||||
channelRoute.POST("/fetch_models", controller.FetchModels)
|
||||
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
|
||||
}
|
||||
tokenRoute := apiRouter.Group("/token")
|
||||
tokenRoute.Use(middleware.UserAuth())
|
||||
|
||||
29
service/cf_worker.go
Normal file
29
service/cf_worker.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
|
||||
if setting.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
|
||||
if !strings.HasPrefix(originUrl, "https") {
|
||||
return nil, fmt.Errorf("only support https url")
|
||||
}
|
||||
workerUrl := setting.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
// post request to worker
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
|
||||
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if constant.CustomCallbackAddress == "" {
|
||||
return constant.ServerAddress
|
||||
if setting.CustomCallbackAddress == "" {
|
||||
return setting.ServerAddress
|
||||
}
|
||||
return constant.CustomCallbackAddress
|
||||
return setting.CustomCallbackAddress
|
||||
}
|
||||
|
||||
39
service/file_decoder.go
Normal file
39
service/file_decoder.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
)
|
||||
|
||||
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
|
||||
|
||||
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
resp, err := DoDownloadRequest(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Always use LimitReader to prevent oversized downloads
|
||||
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check actual size after reading
|
||||
if len(fileBytes) > maxFileSize {
|
||||
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
|
||||
}
|
||||
|
||||
// Convert to base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
return &dto.LocalFileData{
|
||||
Base64Data: base64Data,
|
||||
MimeType: resp.Header.Get("Content-Type"),
|
||||
Size: int64(len(fileBytes)),
|
||||
}, nil
|
||||
}
|
||||
@@ -5,11 +5,12 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/image/webp"
|
||||
"image"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
@@ -31,14 +32,39 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
|
||||
return config, format, base64String, err
|
||||
}
|
||||
|
||||
func DecodeBase64FileData(base64String string) (string, string, error) {
|
||||
var mimeType string
|
||||
var idx int
|
||||
idx = strings.Index(base64String, ",")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = base64String[:idx]
|
||||
base64String = base64String[idx+1:]
|
||||
idx = strings.Index(mimeType, ";")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = mimeType[:idx]
|
||||
idx = strings.Index(mimeType, ":")
|
||||
if idx == -1 {
|
||||
_, file_type, base64, err := DecodeBase64ImageData(base64String)
|
||||
return "image/" + file_type, base64, err
|
||||
}
|
||||
mimeType = mimeType[idx+1:]
|
||||
return mimeType, base64String, nil
|
||||
}
|
||||
|
||||
// GetImageFromUrl 获取图片的类型和base64编码的数据
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
resp, err := DoImageRequest(url)
|
||||
resp, err := DoDownloadRequest(url)
|
||||
if err != nil {
|
||||
return
|
||||
return "", "", err
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return
|
||||
return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
@@ -52,7 +78,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
response, err := DoImageRequest(imageUrl)
|
||||
response, err := DoDownloadRequest(imageUrl)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, "", err
|
||||
@@ -64,6 +90,12 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
mimeType := response.Header.Get("Content-Type")
|
||||
|
||||
if !strings.HasPrefix(mimeType, "image/") {
|
||||
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
|
||||
}
|
||||
|
||||
var readData []byte
|
||||
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
|
||||
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -167,16 +168,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
if !constant.MjAccountFilterEnabled {
|
||||
if !setting.MjAccountFilterEnabled {
|
||||
delete(mapResult, "accountFilter")
|
||||
}
|
||||
if !constant.MjNotifyEnabled {
|
||||
if !setting.MjNotifyEnabled {
|
||||
delete(mapResult, "notifyHook")
|
||||
}
|
||||
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
// make new request with mapResult
|
||||
}
|
||||
if constant.MjModeClearEnabled {
|
||||
if setting.MjModeClearEnabled {
|
||||
if prompt, ok := mapResult["prompt"].(string); ok {
|
||||
prompt = strings.Replace(prompt, "--fast", "", -1)
|
||||
prompt = strings.Replace(prompt, "--relax", "", -1)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -17,12 +18,12 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
if relayInfo.UsePrice {
|
||||
return nil
|
||||
}
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId)
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"))
|
||||
token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
completionRatio := common.GetCompletionRatio(modelName)
|
||||
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
|
||||
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
|
||||
groupRatio := common.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
modelRatio := common.GetModelRatio(modelName)
|
||||
|
||||
ratio := groupRatio * modelRatio
|
||||
@@ -57,15 +58,11 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
|
||||
}
|
||||
|
||||
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
|
||||
err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
|
||||
err = model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -119,7 +116,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
//}
|
||||
//quotaDelta := quota - preConsumedQuota
|
||||
//if quotaDelta != 0 {
|
||||
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
// err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
// if err != nil {
|
||||
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
// }
|
||||
@@ -139,7 +136,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -189,15 +186,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
} else {
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
err := model.CacheUpdateUserQuota(relayInfo.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
@@ -208,5 +201,5 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -56,7 +56,7 @@ func CheckSensitiveInput(input any) error {
|
||||
|
||||
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
|
||||
func SensitiveWordContains(text string) (bool, []string) {
|
||||
if len(constant.SensitiveWords) == 0 {
|
||||
if len(setting.SensitiveWords) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
checkText := strings.ToLower(text)
|
||||
@@ -75,7 +75,7 @@ func SensitiveWordContains(text string) (bool, []string) {
|
||||
|
||||
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
|
||||
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
|
||||
if len(constant.SensitiveWords) == 0 {
|
||||
if len(setting.SensitiveWords) == 0 {
|
||||
return false, nil, text
|
||||
}
|
||||
checkText := strings.ToLower(text)
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
goahocorasick "github.com/anknown/ahocorasick"
|
||||
"one-api/constant"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -70,7 +70,7 @@ func InitAc() *goahocorasick.Machine {
|
||||
func readRunes() [][]rune {
|
||||
var dict [][]rune
|
||||
|
||||
for _, word := range constant.SensitiveWords {
|
||||
for _, word := range setting.SensitiveWords {
|
||||
word = strings.ToLower(word)
|
||||
l := bytes.TrimSpace([]byte(word))
|
||||
dict = append(dict, bytes.Runes(l))
|
||||
|
||||
@@ -19,42 +19,40 @@ import (
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
var cl200kTokenEncoder *tiktoken.Tiktoken
|
||||
var o200kTokenEncoder *tiktoken.Tiktoken
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||
cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||
}
|
||||
defaultTokenEncoder = gpt35TokenEncoder
|
||||
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
|
||||
defaultTokenEncoder = cl100TokenEncoder
|
||||
o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
||||
}
|
||||
for model, _ := range common.GetDefaultModelRatioMap() {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
tokenEncoderMap[model] = cl100TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
if strings.HasPrefix(model, "gpt-4o") {
|
||||
tokenEncoderMap[model] = cl200kTokenEncoder
|
||||
tokenEncoderMap[model] = o200kTokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
}
|
||||
} else if strings.HasPrefix(model, "o1") {
|
||||
tokenEncoderMap[model] = o200kTokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = nil
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
}
|
||||
}
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
|
||||
return cl200kTokenEncoder
|
||||
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
|
||||
return o200kTokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
@@ -82,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
||||
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
||||
baseTokens := 85
|
||||
if model == "glm-4v" {
|
||||
return 1047, nil
|
||||
@@ -92,11 +90,14 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
}
|
||||
// TODO: 非流模式下不计算图片token数量
|
||||
if !constant.GetMediaTokenNotStream && !stream {
|
||||
return 1000, nil
|
||||
return 256, nil
|
||||
}
|
||||
// 是否统计图片token
|
||||
if !constant.GetMediaToken {
|
||||
return 1000, nil
|
||||
return 256, nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
|
||||
return 256, nil
|
||||
}
|
||||
// 同步One API的图片计费逻辑
|
||||
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
||||
@@ -157,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
|
||||
return tiles*tileTokens + baseTokens, nil
|
||||
}
|
||||
|
||||
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
|
||||
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||||
tkm := 0
|
||||
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
|
||||
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -181,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens, err := CountTokenInput(countStr, model)
|
||||
toolTokens, err := CountTokenInput(countStr, request.Model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -258,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
return textToken, audioToken, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
|
||||
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
@@ -292,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
|
||||
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
|
||||
if constant.EnableWorker() {
|
||||
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
|
||||
workerUrl := constant.WorkerUrl
|
||||
if !strings.HasSuffix(workerUrl, "/") {
|
||||
workerUrl += "/"
|
||||
}
|
||||
// post request to worker
|
||||
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
|
||||
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
|
||||
return http.Get(originUrl)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,33 +1,47 @@
|
||||
package common
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var GroupRatio = map[string]float64{
|
||||
var groupRatio = map[string]float64{
|
||||
"default": 1,
|
||||
"vip": 1,
|
||||
"svip": 1,
|
||||
}
|
||||
|
||||
func GetGroupRatioCopy() map[string]float64 {
|
||||
groupRatioCopy := make(map[string]float64)
|
||||
for k, v := range groupRatio {
|
||||
groupRatioCopy[k] = v
|
||||
}
|
||||
return groupRatioCopy
|
||||
}
|
||||
|
||||
func ContainsGroupRatio(name string) bool {
|
||||
_, ok := groupRatio[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func GroupRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(GroupRatio)
|
||||
jsonBytes, err := json.Marshal(groupRatio)
|
||||
if err != nil {
|
||||
SysError("error marshalling model ratio: " + err.Error())
|
||||
common.SysError("error marshalling model ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateGroupRatioByJSONString(jsonStr string) error {
|
||||
GroupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
|
||||
groupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &groupRatio)
|
||||
}
|
||||
|
||||
func GetGroupRatio(name string) float64 {
|
||||
ratio, ok := GroupRatio[name]
|
||||
ratio, ok := groupRatio[name]
|
||||
if !ok {
|
||||
SysError("group ratio not found: " + name)
|
||||
common.SysError("group ratio not found: " + name)
|
||||
return 1
|
||||
}
|
||||
return ratio
|
||||
7
setting/midjourney.go
Normal file
7
setting/midjourney.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package setting
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
var MjAccountFilterEnabled = false
|
||||
var MjModeClearEnabled = false
|
||||
var MjForwardUrlEnabled = true
|
||||
var MjActionCheckSuccessEnabled = true
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
import "strings"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package constant
|
||||
package setting
|
||||
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var WorkerUrl = ""
|
||||
52
setting/user_usable_group.go
Normal file
52
setting/user_usable_group.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var userUsableGroups = map[string]string{
|
||||
"default": "默认分组",
|
||||
"vip": "vip分组",
|
||||
}
|
||||
|
||||
func GetUserUsableGroupsCopy() map[string]string {
|
||||
copyUserUsableGroups := make(map[string]string)
|
||||
for k, v := range userUsableGroups {
|
||||
copyUserUsableGroups[k] = v
|
||||
}
|
||||
return copyUserUsableGroups
|
||||
}
|
||||
|
||||
func UserUsableGroups2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(userUsableGroups)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling user groups: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
|
||||
userUsableGroups = make(map[string]string)
|
||||
return json.Unmarshal([]byte(jsonStr), &userUsableGroups)
|
||||
}
|
||||
|
||||
func GetUserUsableGroups(userGroup string) map[string]string {
|
||||
groupsCopy := GetUserUsableGroupsCopy()
|
||||
if userGroup == "" {
|
||||
if _, ok := groupsCopy["default"]; !ok {
|
||||
groupsCopy["default"] = "default"
|
||||
}
|
||||
}
|
||||
// 如果userGroup不在UserUsableGroups中,返回UserUsableGroups + userGroup
|
||||
if _, ok := groupsCopy[userGroup]; !ok {
|
||||
groupsCopy[userGroup] = "用户分组"
|
||||
}
|
||||
// 如果userGroup在UserUsableGroups中,返回UserUsableGroups
|
||||
return groupsCopy
|
||||
}
|
||||
|
||||
func GroupInUserUsableGroups(groupName string) bool {
|
||||
_, ok := userUsableGroups[groupName]
|
||||
return ok
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<html lang="zh">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="/logo.png" />
|
||||
|
||||
@@ -162,9 +162,15 @@ const ChannelsTable = () => {
|
||||
return (
|
||||
<div>
|
||||
<Space spacing={2}>
|
||||
{text?.split(',').map((item, index) => {
|
||||
return renderGroup(item);
|
||||
})}
|
||||
{text?.split(',')
|
||||
.sort((a, b) => {
|
||||
if (a === 'default') return -1;
|
||||
if (b === 'default') return 1;
|
||||
return a.localeCompare(b);
|
||||
})
|
||||
.map((item, index) => {
|
||||
return renderGroup(item);
|
||||
})}
|
||||
</Space>
|
||||
</div>
|
||||
);
|
||||
@@ -507,6 +513,8 @@ const ChannelsTable = () => {
|
||||
const [selectedChannels, setSelectedChannels] = useState([]);
|
||||
const [showEditPriority, setShowEditPriority] = useState(false);
|
||||
const [enableTagMode, setEnableTagMode] = useState(false);
|
||||
const [showBatchSetTag, setShowBatchSetTag] = useState(false);
|
||||
const [batchSetTagValue, setBatchSetTagValue] = useState('');
|
||||
|
||||
|
||||
const removeRecord = (record) => {
|
||||
@@ -968,6 +976,29 @@ const ChannelsTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const batchSetChannelTag = async () => {
|
||||
if (selectedChannels.length === 0) {
|
||||
showError(t('请先选择要设置标签的渠道!'));
|
||||
return;
|
||||
}
|
||||
if (batchSetTagValue === '') {
|
||||
showError(t('标签不能为空!'));
|
||||
return;
|
||||
}
|
||||
let ids = selectedChannels.map(channel => channel.id);
|
||||
const res = await API.post('/api/channel/batch/tag', {
|
||||
ids: ids,
|
||||
tag: batchSetTagValue === '' ? null : batchSetTagValue
|
||||
});
|
||||
if (res.data.success) {
|
||||
showSuccess(t('已为 ${count} 个渠道设置标签!').replace('${count}', res.data.data));
|
||||
await refresh();
|
||||
setShowBatchSetTag(false);
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<EditTagModal
|
||||
@@ -1115,11 +1146,11 @@ const ChannelsTable = () => {
|
||||
</div>
|
||||
<div style={{ marginTop: 20 }}>
|
||||
<Space>
|
||||
<Typography.Text strong>{t('开启批量删除')}</Typography.Text>
|
||||
<Typography.Text strong>{t('开启批量操作')}</Typography.Text>
|
||||
<Switch
|
||||
label={t('开启批量删除')}
|
||||
label={t('开启批量操作')}
|
||||
uncheckedText={t('关')}
|
||||
aria-label={t('是否开启批量删除')}
|
||||
aria-label={t('是否开启批量操作')}
|
||||
onChange={(v) => {
|
||||
setEnableBatchDelete(v);
|
||||
}}
|
||||
@@ -1167,7 +1198,17 @@ const ChannelsTable = () => {
|
||||
loadChannels(0, pageSize, idSort, v);
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
disabled={!enableBatchDelete}
|
||||
theme="light"
|
||||
type="primary"
|
||||
style={{ marginRight: 8 }}
|
||||
onClick={() => setShowBatchSetTag(true)}
|
||||
>
|
||||
{t('批量设置标签')}
|
||||
</Button>
|
||||
</Space>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@@ -1201,6 +1242,23 @@ const ChannelsTable = () => {
|
||||
: null
|
||||
}
|
||||
/>
|
||||
<Modal
|
||||
title={t('批量设置标签')}
|
||||
visible={showBatchSetTag}
|
||||
onOk={batchSetChannelTag}
|
||||
onCancel={() => setShowBatchSetTag(false)}
|
||||
maskClosable={false}
|
||||
centered={true}
|
||||
>
|
||||
<div style={{ marginBottom: 20 }}>
|
||||
<Typography.Text>{t('请输入要设置的标签名称')}</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
placeholder={t('请输入标签名称')}
|
||||
value={batchSetTagValue}
|
||||
onChange={(v) => setBatchSetTagValue(v)}
|
||||
/>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -199,7 +199,7 @@ const HeaderBar = () => {
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Nav.Item itemKey={'new-year'} text={'🏮'} />
|
||||
<Nav.Item itemKey={'new-year'} text={'🎉'} />
|
||||
</Dropdown>
|
||||
)}
|
||||
{/* <Nav.Item itemKey={'about'} icon={<IconHelpCircle />} /> */}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
API,
|
||||
@@ -25,7 +25,7 @@ import {
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { ITEMS_PER_PAGE } from '../constants';
|
||||
import {
|
||||
renderAudioModelPrice,
|
||||
renderAudioModelPrice, renderGroup,
|
||||
renderModelPrice, renderModelPriceSimple,
|
||||
renderNumber,
|
||||
renderQuota,
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
} from '../helpers/render';
|
||||
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
||||
import { getLogOther } from '../helpers/other.js';
|
||||
import { StyleContext } from '../context/Style/index.js';
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
@@ -184,7 +185,10 @@ const LogsTable = () => {
|
||||
size='small'
|
||||
color={stringToColor(text)}
|
||||
style={{ marginRight: 4 }}
|
||||
onClick={() => showUserInfo(record.user_id)}
|
||||
onClick={(event) => {
|
||||
event.stopPropagation();
|
||||
showUserInfo(record.user_id)
|
||||
}}
|
||||
>
|
||||
{typeof text === 'string' && text.slice(0, 1)}
|
||||
</Avatar>
|
||||
@@ -204,8 +208,9 @@ const LogsTable = () => {
|
||||
<Tag
|
||||
color='grey'
|
||||
size='large'
|
||||
onClick={() => {
|
||||
copyText(text);
|
||||
onClick={(event) => {
|
||||
//cancel the row click event
|
||||
copyText(event, text);
|
||||
}}
|
||||
>
|
||||
{' '}
|
||||
@@ -217,6 +222,37 @@ const LogsTable = () => {
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('分组'),
|
||||
dataIndex: 'group',
|
||||
render: (text, record, index) => {
|
||||
if (record.type === 0 || record.type === 2) {
|
||||
if (record.group) {
|
||||
return (
|
||||
<>
|
||||
{renderGroup(record.group)}
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
let other = JSON.parse(record.other);
|
||||
if (other === null) {
|
||||
return <></>;
|
||||
}
|
||||
if (other.group !== undefined) {
|
||||
return (
|
||||
<>
|
||||
{renderGroup(other.group)}
|
||||
</>
|
||||
);
|
||||
} else {
|
||||
return <></>;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return <></>;
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('类型'),
|
||||
dataIndex: 'type',
|
||||
@@ -233,8 +269,8 @@ const LogsTable = () => {
|
||||
<Tag
|
||||
color={stringToColor(text)}
|
||||
size='large'
|
||||
onClick={() => {
|
||||
copyText(text);
|
||||
onClick={(event) => {
|
||||
copyText(event, text);
|
||||
}}
|
||||
>
|
||||
{' '}
|
||||
@@ -375,6 +411,7 @@ const LogsTable = () => {
|
||||
},
|
||||
];
|
||||
|
||||
const [styleState, styleDispatch] = useContext(StyleContext);
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [expandData, setExpandData] = useState({});
|
||||
const [showStat, setShowStat] = useState(false);
|
||||
@@ -394,6 +431,7 @@ const LogsTable = () => {
|
||||
start_timestamp: timestamp2string(getTodayStartTimestamp()),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: '',
|
||||
group: '',
|
||||
});
|
||||
const {
|
||||
username,
|
||||
@@ -402,6 +440,7 @@ const LogsTable = () => {
|
||||
start_timestamp,
|
||||
end_timestamp,
|
||||
channel,
|
||||
group,
|
||||
} = inputs;
|
||||
|
||||
const [stat, setStat] = useState({
|
||||
@@ -410,13 +449,13 @@ const LogsTable = () => {
|
||||
});
|
||||
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
setInputs(inputs => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
|
||||
const getLogSelfStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
|
||||
url = encodeURI(url);
|
||||
let res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
@@ -430,7 +469,7 @@ const LogsTable = () => {
|
||||
const getLogStat = async () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
|
||||
url = encodeURI(url);
|
||||
let res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
@@ -483,7 +522,7 @@ const LogsTable = () => {
|
||||
let expandDatesLocal = {};
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = i;
|
||||
logs[i].key = logs[i].id;
|
||||
let other = getLogOther(logs[i].other);
|
||||
let expandDataLocal = [];
|
||||
if (isAdmin()) {
|
||||
@@ -573,9 +612,9 @@ const LogsTable = () => {
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
if (isAdminUser) {
|
||||
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
|
||||
} else {
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
|
||||
}
|
||||
url = encodeURI(url);
|
||||
const res = await API.get(url);
|
||||
@@ -615,11 +654,12 @@ const LogsTable = () => {
|
||||
await loadLogs(activePage, pageSize, logType);
|
||||
};
|
||||
|
||||
const copyText = async (text) => {
|
||||
const copyText = async (e, text) => {
|
||||
e.stopPropagation();
|
||||
if (await copy(text)) {
|
||||
showSuccess('已复制:' + text);
|
||||
} else {
|
||||
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
|
||||
Modal.error({ title: t('无法复制到剪贴板,请手动复制'), content: text });
|
||||
}
|
||||
};
|
||||
|
||||
@@ -659,10 +699,53 @@ const LogsTable = () => {
|
||||
</Header>
|
||||
<Form layout='horizontal' style={{ marginTop: 10 }}>
|
||||
<>
|
||||
<Form.Section>
|
||||
<div style={{ marginBottom: 10 }}>
|
||||
{
|
||||
styleState.isMobile ? (
|
||||
<div>
|
||||
<Form.DatePicker
|
||||
field='start_timestamp'
|
||||
label={t('起始时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
type='dateTime'
|
||||
onChange={(value) => {
|
||||
console.log(value);
|
||||
handleInputChange(value, 'start_timestamp')
|
||||
}}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='end_timestamp'
|
||||
fluid
|
||||
label={t('结束时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
type='dateTime'
|
||||
onChange={(value) => handleInputChange(value, 'end_timestamp')}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<Form.DatePicker
|
||||
field="range_timestamp"
|
||||
label={t('时间范围')}
|
||||
initValue={[start_timestamp, end_timestamp]}
|
||||
type="dateTimeRange"
|
||||
name="range_timestamp"
|
||||
onChange={(value) => {
|
||||
if (Array.isArray(value) && value.length === 2) {
|
||||
handleInputChange(value[0], 'start_timestamp');
|
||||
handleInputChange(value[1], 'end_timestamp');
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</Form.Section>
|
||||
<Form.Input
|
||||
field='token_name'
|
||||
label={t('令牌名称')}
|
||||
style={{ width: 176 }}
|
||||
value={token_name}
|
||||
placeholder={t('可选值')}
|
||||
name='token_name'
|
||||
@@ -671,39 +754,24 @@ const LogsTable = () => {
|
||||
<Form.Input
|
||||
field='model_name'
|
||||
label={t('模型名称')}
|
||||
style={{ width: 176 }}
|
||||
value={model_name}
|
||||
placeholder={t('可选值')}
|
||||
name='model_name'
|
||||
onChange={(value) => handleInputChange(value, 'model_name')}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='start_timestamp'
|
||||
label={t('起始时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp}
|
||||
type='dateTime'
|
||||
name='start_timestamp'
|
||||
onChange={(value) => handleInputChange(value, 'start_timestamp')}
|
||||
/>
|
||||
<Form.DatePicker
|
||||
field='end_timestamp'
|
||||
fluid
|
||||
label={t('结束时间')}
|
||||
style={{ width: 272 }}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp}
|
||||
type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={(value) => handleInputChange(value, 'end_timestamp')}
|
||||
<Form.Input
|
||||
field='group'
|
||||
label={t('分组')}
|
||||
value={group}
|
||||
placeholder={t('可选值')}
|
||||
name='group'
|
||||
onChange={(value) => handleInputChange(value, 'group')}
|
||||
/>
|
||||
{isAdminUser && (
|
||||
<>
|
||||
<Form.Input
|
||||
field='channel'
|
||||
label={t('渠道 ID')}
|
||||
style={{ width: 176 }}
|
||||
value={channel}
|
||||
placeholder={t('可选值')}
|
||||
name='channel'
|
||||
@@ -712,7 +780,6 @@ const LogsTable = () => {
|
||||
<Form.Input
|
||||
field='username'
|
||||
label={t('用户名称')}
|
||||
style={{ width: 176 }}
|
||||
value={username}
|
||||
placeholder={t('可选值')}
|
||||
name='username'
|
||||
|
||||
@@ -81,41 +81,24 @@ const ModelPricing = () => {
|
||||
}
|
||||
|
||||
function renderAvailable(available) {
|
||||
return available ? (
|
||||
return (
|
||||
<Popover
|
||||
content={
|
||||
<div style={{ padding: 8 }}>{t('您的分组可以使用该模型')}</div>
|
||||
}
|
||||
position='top'
|
||||
key={available}
|
||||
style={{
|
||||
backgroundColor: 'rgba(var(--semi-blue-4),1)',
|
||||
borderColor: 'rgba(var(--semi-blue-4),1)',
|
||||
color: 'var(--semi-color-white)',
|
||||
borderWidth: 1,
|
||||
borderStyle: 'solid',
|
||||
}}
|
||||
content={
|
||||
<div style={{ padding: 8 }}>{t('您的分组可以使用该模型')}</div>
|
||||
}
|
||||
position='top'
|
||||
key={available}
|
||||
style={{
|
||||
backgroundColor: 'rgba(var(--semi-blue-4),1)',
|
||||
borderColor: 'rgba(var(--semi-blue-4),1)',
|
||||
color: 'var(--semi-color-white)',
|
||||
borderWidth: 1,
|
||||
borderStyle: 'solid',
|
||||
}}
|
||||
>
|
||||
<IconVerify style={{ color: 'green' }} size="large" />
|
||||
<IconVerify style={{ color: 'green' }} size="large" />
|
||||
</Popover>
|
||||
) : (
|
||||
<Popover
|
||||
content={
|
||||
<div style={{ padding: 8 }}>{t('您的分组无权使用该模型')}</div>
|
||||
}
|
||||
position='top'
|
||||
key={available}
|
||||
style={{
|
||||
backgroundColor: 'rgba(var(--semi-blue-4),1)',
|
||||
borderColor: 'rgba(var(--semi-blue-4),1)',
|
||||
color: 'var(--semi-color-white)',
|
||||
borderWidth: 1,
|
||||
borderStyle: 'solid',
|
||||
}}
|
||||
>
|
||||
<IconUploadError style={{ color: '#FFA54F' }} size="large" />
|
||||
</Popover>
|
||||
);
|
||||
)
|
||||
}
|
||||
|
||||
const columns = [
|
||||
@@ -162,36 +145,39 @@ const ModelPricing = () => {
|
||||
title: t('可用分组'),
|
||||
dataIndex: 'enable_groups',
|
||||
render: (text, record, index) => {
|
||||
|
||||
// enable_groups is a string array
|
||||
return (
|
||||
<Space>
|
||||
{text.map((group) => {
|
||||
if (group === selectedGroup) {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
prefixIcon={<IconVerify />}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
onClick={() => {
|
||||
setSelectedGroup(group);
|
||||
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
|
||||
group: group,
|
||||
ratio: groupRatio[group]
|
||||
}));
|
||||
}}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
if (usableGroup[group]) {
|
||||
if (group === selectedGroup) {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
prefixIcon={<IconVerify />}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<Tag
|
||||
color='blue'
|
||||
size='large'
|
||||
onClick={() => {
|
||||
setSelectedGroup(group);
|
||||
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
|
||||
group: group,
|
||||
ratio: groupRatio[group]
|
||||
}));
|
||||
}}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
}
|
||||
})}
|
||||
</Space>
|
||||
@@ -275,6 +261,7 @@ const ModelPricing = () => {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [groupRatio, setGroupRatio] = useState({});
|
||||
const [usableGroup, setUsableGroup] = useState({});
|
||||
|
||||
const setModelsFormat = (models, groupRatio) => {
|
||||
for (let i = 0; i < models.length; i++) {
|
||||
@@ -309,9 +296,10 @@ const ModelPricing = () => {
|
||||
let url = '';
|
||||
url = `/api/pricing`;
|
||||
const res = await API.get(url);
|
||||
const { success, message, data, group_ratio } = res.data;
|
||||
const { success, message, data, group_ratio, usable_group } = res.data;
|
||||
if (success) {
|
||||
setGroupRatio(group_ratio);
|
||||
setUsableGroup(usable_group);
|
||||
setSelectedGroup(userState.user ? userState.user.group : 'default')
|
||||
setModelsFormat(data, group_ratio);
|
||||
} else {
|
||||
|
||||
@@ -146,8 +146,9 @@ const PersonalSetting = () => {
|
||||
let res = await API.get(`/api/user/models`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
setModels(data);
|
||||
console.log(data);
|
||||
if (data != null) {
|
||||
setModels(data);
|
||||
}
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
|
||||
@@ -406,7 +406,7 @@ const UsersTable = () => {
|
||||
if (searchKeyword === '') {
|
||||
await loadUsers(activePage - 1);
|
||||
} else {
|
||||
await searchUsers();
|
||||
await searchUsers(searchKeyword, searchGroup);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import i18next from 'i18next';
|
||||
import { Tag } from '@douyinfe/semi-ui';
|
||||
import { Modal, Tag, Typography } from '@douyinfe/semi-ui';
|
||||
import { copy, showSuccess } from './utils.js';
|
||||
|
||||
export function renderText(text, limit) {
|
||||
if (text.length > limit) {
|
||||
@@ -38,6 +39,14 @@ export function renderGroup(group) {
|
||||
size='large'
|
||||
color={tagColors[group] || stringToColor(group)}
|
||||
key={group}
|
||||
onClick={async (event) => {
|
||||
event.stopPropagation();
|
||||
if (await copy(group)) {
|
||||
showSuccess(i18next.t('已复制:') + group);
|
||||
} else {
|
||||
Modal.error({ title: t('无法复制到剪贴板,请手动复制'), content: group });
|
||||
}
|
||||
}}
|
||||
>
|
||||
{group}
|
||||
</Tag>
|
||||
@@ -46,6 +55,81 @@ export function renderGroup(group) {
|
||||
);
|
||||
}
|
||||
|
||||
export function renderRatio(ratio) {
|
||||
let color = 'green';
|
||||
if (ratio > 5) {
|
||||
color = 'red';
|
||||
} else if (ratio > 3) {
|
||||
color = 'orange';
|
||||
} else if (ratio > 1) {
|
||||
color = 'blue';
|
||||
}
|
||||
return <Tag color={color}>{ratio}x {i18next.t('倍率')}</Tag>;
|
||||
}
|
||||
|
||||
export const renderGroupOption = (item) => {
|
||||
const {
|
||||
disabled,
|
||||
selected,
|
||||
label,
|
||||
value,
|
||||
focused,
|
||||
className,
|
||||
style,
|
||||
onMouseEnter,
|
||||
onClick,
|
||||
empty,
|
||||
emptyContent,
|
||||
...rest
|
||||
} = item;
|
||||
|
||||
const baseStyle = {
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
padding: '8px 16px',
|
||||
cursor: disabled ? 'not-allowed' : 'pointer',
|
||||
backgroundColor: focused ? 'var(--semi-color-fill-0)' : 'transparent',
|
||||
opacity: disabled ? 0.5 : 1,
|
||||
...(selected && {
|
||||
backgroundColor: 'var(--semi-color-primary-light-default)',
|
||||
}),
|
||||
'&:hover': {
|
||||
backgroundColor: !disabled && 'var(--semi-color-fill-1)'
|
||||
}
|
||||
};
|
||||
|
||||
const handleClick = () => {
|
||||
if (!disabled && onClick) {
|
||||
onClick();
|
||||
}
|
||||
};
|
||||
|
||||
const handleMouseEnter = (e) => {
|
||||
if (!disabled && onMouseEnter) {
|
||||
onMouseEnter(e);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
style={baseStyle}
|
||||
onClick={handleClick}
|
||||
onMouseEnter={handleMouseEnter}
|
||||
>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
|
||||
<Typography.Text strong type={disabled ? 'tertiary' : undefined}>
|
||||
{value}
|
||||
</Typography.Text>
|
||||
<Typography.Text type="secondary" size="small">
|
||||
{label}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
{item.ratio && renderRatio(item.ratio)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export function renderNumber(num) {
|
||||
if (num >= 1000000000) {
|
||||
return (num / 1000000000).toFixed(1) + 'B';
|
||||
@@ -59,6 +143,9 @@ export function renderNumber(num) {
|
||||
}
|
||||
|
||||
export function renderQuotaNumberWithDigit(num, digits = 2) {
|
||||
if (typeof num !== 'number' || isNaN(num)) {
|
||||
return 0;
|
||||
}
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
num = num.toFixed(digits);
|
||||
if (displayInCurrency) {
|
||||
@@ -340,7 +427,7 @@ export const modelColorMap = {
|
||||
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
|
||||
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
|
||||
'gpt-3.5-turbo-16k': 'rgb(149,252,206)', // 淡橙色
|
||||
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃<EFBFBD><EFBFBD><EFBFBD>
|
||||
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃
|
||||
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
|
||||
'gpt-4': 'rgb(135,206,235)', // 天蓝色
|
||||
// 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
|
||||
@@ -363,7 +450,7 @@ export const modelColorMap = {
|
||||
'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红
|
||||
'text-embedding-v1': 'rgb(255,174,185)', // 浅粉红色(略有区别)
|
||||
'text-moderation-latest': 'rgb(255,130,171)', // 强粉色
|
||||
'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(<EFBFBD><EFBFBD><EFBFBD>Babbage相同,表示同一类功能)
|
||||
'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(与Babbage相同,表示同一类功能)
|
||||
'tts-1': 'rgb(255,140,0)', // 深橙色
|
||||
'tts-1-1106': 'rgb(255,165,0)', // 橙色
|
||||
'tts-1-hd': 'rgb(255,215,0)', // 金色
|
||||
|
||||
@@ -49,8 +49,18 @@ export async function copy(text) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} catch (e) {
|
||||
okay = false;
|
||||
console.error(e);
|
||||
try {
|
||||
// 构建input 执行 复制命令
|
||||
var _input = window.document.createElement("input");
|
||||
_input.value = text;
|
||||
window.document.body.appendChild(_input);
|
||||
_input.select();
|
||||
window.document.execCommand("Copy");
|
||||
window.document.body.removeChild(_input);
|
||||
} catch (e) {
|
||||
okay = false;
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return okay;
|
||||
}
|
||||
@@ -180,6 +190,9 @@ export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') {
|
||||
let month = (date.getMonth() + 1).toString();
|
||||
let day = date.getDate().toString();
|
||||
let hour = date.getHours().toString();
|
||||
if (day === '24') {
|
||||
console.log("timestamp", timestamp);
|
||||
}
|
||||
if (month.length === 1) {
|
||||
month = '0' + month;
|
||||
}
|
||||
|
||||
@@ -546,8 +546,8 @@
|
||||
"是否用ID排序": "Whether to sort by ID",
|
||||
"确定?": "Sure?",
|
||||
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
|
||||
"开启批量删除": "Enable batch selection",
|
||||
"是否开启批量删除": "Whether to enable batch selection",
|
||||
"开启批量操作": "Enable batch selection",
|
||||
"是否开启批量操作": "Whether to enable batch selection",
|
||||
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
|
||||
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
|
||||
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",
|
||||
|
||||
@@ -430,6 +430,8 @@
|
||||
"一小时后过期": "Expires after one hour",
|
||||
"一分钟后过期": "Expires after one minute",
|
||||
"创建新的令牌": "Create New Token",
|
||||
"令牌分组,默认为用户的分组": "Token group, default is the your's group",
|
||||
"IP白名单(请勿过度信任此功能)": "IP whitelist (do not overly trust this function)",
|
||||
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
|
||||
"设为无限额度": "Set to unlimited quota",
|
||||
"更新令牌信息": "Update Token Information",
|
||||
@@ -546,8 +548,8 @@
|
||||
"是否用ID排序": "Whether to sort by ID",
|
||||
"确定?": "Sure?",
|
||||
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
|
||||
"开启批量删除": "Enable batch selection",
|
||||
"是否开启批量删除": "Whether to enable batch selection",
|
||||
"开启批量操作": "Enable batch selection",
|
||||
"是否开启批量操作": "Whether to enable batch selection",
|
||||
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
|
||||
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
|
||||
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",
|
||||
@@ -866,8 +868,8 @@
|
||||
"请选择模式": "Please select mode",
|
||||
"图片代理方式": "Picture agency method",
|
||||
"用于替换 https://cdn.discordapp.com 的域名": "The domain name used to replace https://cdn.discordapp.com",
|
||||
"一个月": "a month",
|
||||
"一天": "one day",
|
||||
"一个月": "A month",
|
||||
"一天": "One day",
|
||||
"令牌渠道分组选择": "Token channel grouping selection",
|
||||
"只可使用对应分组包含的模型。": "Only models contained in the corresponding group can be used.",
|
||||
"渠道分组": "Channel grouping",
|
||||
@@ -876,7 +878,7 @@
|
||||
"启用模型限制(非必要,不建议启用)": "Enable model restrictions (not necessary, not recommended)",
|
||||
"秒": "Second",
|
||||
"更新令牌后需等待几分钟生效": "It will take a few minutes to take effect after updating the token.",
|
||||
"一小时": "one hour",
|
||||
"一小时": "One hour",
|
||||
"新建数量": "New quantity",
|
||||
"加载失败,请稍后重试": "Loading failed, please try again later",
|
||||
"未设置": "Not set",
|
||||
@@ -1234,5 +1236,9 @@
|
||||
"应用更改": "Apply changes",
|
||||
"更多": "Expand more",
|
||||
"个模型": "models",
|
||||
"可用模型": "Available models"
|
||||
"可用模型": "Available models",
|
||||
"时间范围": "Time range",
|
||||
"批量设置标签": "Batch set tag",
|
||||
"请输入要设置的标签名称": "Please enter the tag name to be set",
|
||||
"请输入标签名称": "Please enter the tag name"
|
||||
}
|
||||
@@ -193,14 +193,16 @@ const EditChannel = (props) => {
|
||||
|
||||
|
||||
const fetchUpstreamModelList = async (name) => {
|
||||
if (inputs['type'] !== 1) {
|
||||
showError(t('仅支持 OpenAI 接口格式'));
|
||||
return;
|
||||
}
|
||||
// if (inputs['type'] !== 1) {
|
||||
// showError(t('仅支持 OpenAI 接口格式'));
|
||||
// return;
|
||||
// }
|
||||
setLoading(true);
|
||||
const models = inputs['models'] || [];
|
||||
let err = false;
|
||||
|
||||
if (isEdit) {
|
||||
// 如果是编辑模式,使用已有的channel id获取模型列表
|
||||
const res = await API.get('/api/channel/fetch_models/' + channelId);
|
||||
if (res.data && res.data?.success) {
|
||||
models.push(...res.data.data);
|
||||
@@ -208,30 +210,29 @@ const EditChannel = (props) => {
|
||||
err = true;
|
||||
}
|
||||
} else {
|
||||
// 如果是新建模式,通过后端代理获取模型列表
|
||||
if (!inputs?.['key']) {
|
||||
showError(t('请填写密钥'));
|
||||
err = true;
|
||||
} else {
|
||||
try {
|
||||
const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
|
||||
|
||||
const url = `https://${host.hostname}/v1/models`;
|
||||
const key = inputs['key'];
|
||||
const res = await axios.get(url, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${key}`
|
||||
}
|
||||
const res = await API.post('/api/channel/fetch_models', {
|
||||
base_url: inputs['base_url'],
|
||||
key: inputs['key']
|
||||
});
|
||||
if (res.data) {
|
||||
models.push(...res.data.data.map((model) => model.id));
|
||||
|
||||
if (res.data && res.data.success) {
|
||||
models.push(...res.data.data);
|
||||
} else {
|
||||
err = true;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models:', error);
|
||||
err = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!err) {
|
||||
handleInputChange(name, Array.from(new Set(models)));
|
||||
showSuccess(t('获取模型列表成功'));
|
||||
@@ -638,7 +639,7 @@ const EditChannel = (props) => {
|
||||
{inputs.type === 21 && (
|
||||
<>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>知识库 ID:</Typography.Text>
|
||||
<Typography.Text strong><EFBFBD><EFBFBD>识库 ID:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
label="知识库 ID"
|
||||
|
||||
@@ -143,8 +143,7 @@ const Detail = (props) => {
|
||||
content: [
|
||||
{
|
||||
key: (datum) => datum['Model'],
|
||||
value: (datum) =>
|
||||
renderQuotaNumberWithDigit(parseFloat(datum['Usage']), 4),
|
||||
value: (datum) => renderQuota(datum['rawQuota'] || 0, 4),
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -152,22 +151,28 @@ const Detail = (props) => {
|
||||
content: [
|
||||
{
|
||||
key: (datum) => datum['Model'],
|
||||
value: (datum) => datum['Usage'],
|
||||
value: (datum) => datum['rawQuota'] || 0,
|
||||
},
|
||||
],
|
||||
updateContent: (array) => {
|
||||
array.sort((a, b) => b.value - a.value);
|
||||
let sum = 0;
|
||||
for (let i = 0; i < array.length; i++) {
|
||||
sum += parseFloat(array[i].value);
|
||||
array[i].value = renderQuotaNumberWithDigit(
|
||||
parseFloat(array[i].value),
|
||||
4,
|
||||
);
|
||||
if (array[i].key == "其他") {
|
||||
continue;
|
||||
}
|
||||
let value = parseFloat(array[i].value);
|
||||
if (isNaN(value)) {
|
||||
value = 0;
|
||||
}
|
||||
if (array[i].datum && array[i].datum.TimeSum) {
|
||||
sum = array[i].datum.TimeSum;
|
||||
}
|
||||
array[i].value = renderQuota(value, 4);
|
||||
}
|
||||
array.unshift({
|
||||
key: t('总计'),
|
||||
value: renderQuotaNumberWithDigit(sum, 4),
|
||||
value: renderQuota(sum, 4),
|
||||
});
|
||||
return array;
|
||||
},
|
||||
@@ -212,19 +217,8 @@ const Detail = (props) => {
|
||||
created_at: now.getTime() / 1000,
|
||||
});
|
||||
}
|
||||
// 根据dataExportDefaultTime重制时间粒度
|
||||
let timeGranularity = 3600;
|
||||
if (dataExportDefaultTime === 'day') {
|
||||
timeGranularity = 86400;
|
||||
} else if (dataExportDefaultTime === 'week') {
|
||||
timeGranularity = 604800;
|
||||
}
|
||||
// sort created_at
|
||||
data.sort((a, b) => a.created_at - b.created_at);
|
||||
data.forEach((item) => {
|
||||
item['created_at'] =
|
||||
Math.floor(item['created_at'] / timeGranularity) * timeGranularity;
|
||||
});
|
||||
updateChartData(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -250,14 +244,14 @@ const Detail = (props) => {
|
||||
let uniqueModels = new Set();
|
||||
let totalTokens = 0;
|
||||
|
||||
// 收集所有唯一的模型名称和时间点
|
||||
let uniqueTimes = new Set();
|
||||
// 收集所有唯一的模型名称
|
||||
data.forEach(item => {
|
||||
uniqueModels.add(item.model_name);
|
||||
uniqueTimes.add(timestamp2string1(item.created_at, dataExportDefaultTime));
|
||||
totalTokens += item.token_used;
|
||||
totalQuota += item.quota;
|
||||
totalTimes += item.count;
|
||||
});
|
||||
|
||||
|
||||
// 处理颜色映射
|
||||
const newModelColors = {};
|
||||
Array.from(uniqueModels).forEach((modelName) => {
|
||||
@@ -267,56 +261,82 @@ const Detail = (props) => {
|
||||
});
|
||||
setModelColors(newModelColors);
|
||||
|
||||
// 处理饼图数据
|
||||
for (let item of data) {
|
||||
totalQuota += item.quota;
|
||||
totalTimes += item.count;
|
||||
|
||||
let pieItem = newPieData.find((it) => it.type === item.model_name);
|
||||
if (pieItem) {
|
||||
pieItem.value += item.count;
|
||||
} else {
|
||||
newPieData.push({
|
||||
type: item.model_name,
|
||||
value: item.count,
|
||||
// 按时间和模型聚合数据
|
||||
let aggregatedData = new Map();
|
||||
data.forEach(item => {
|
||||
const timeKey = timestamp2string1(item.created_at, dataExportDefaultTime);
|
||||
const modelKey = item.model_name;
|
||||
const key = `${timeKey}-${modelKey}`;
|
||||
|
||||
if (!aggregatedData.has(key)) {
|
||||
aggregatedData.set(key, {
|
||||
time: timeKey,
|
||||
model: modelKey,
|
||||
quota: 0,
|
||||
count: 0
|
||||
});
|
||||
}
|
||||
|
||||
const existing = aggregatedData.get(key);
|
||||
existing.quota += item.quota;
|
||||
existing.count += item.count;
|
||||
});
|
||||
|
||||
// 处理饼图数据
|
||||
let modelTotals = new Map();
|
||||
for (let [_, value] of aggregatedData) {
|
||||
if (!modelTotals.has(value.model)) {
|
||||
modelTotals.set(value.model, 0);
|
||||
}
|
||||
modelTotals.set(value.model, modelTotals.get(value.model) + value.count);
|
||||
}
|
||||
|
||||
// 处理柱状图数据
|
||||
let timePoints = Array.from(uniqueTimes);
|
||||
newPieData = Array.from(modelTotals).map(([model, count]) => ({
|
||||
type: model,
|
||||
value: count
|
||||
}));
|
||||
|
||||
// 生成时间点序列
|
||||
let timePoints = Array.from(new Set([...aggregatedData.values()].map(d => d.time)));
|
||||
if (timePoints.length < 7) {
|
||||
// 根据时间粒度生成合适的时间点
|
||||
const generateTimePoints = () => {
|
||||
let lastTime = Math.max(...data.map(item => item.created_at));
|
||||
let points = [];
|
||||
let interval = dataExportDefaultTime === 'hour' ? 3600
|
||||
const lastTime = Math.max(...data.map(item => item.created_at));
|
||||
const interval = dataExportDefaultTime === 'hour' ? 3600
|
||||
: dataExportDefaultTime === 'day' ? 86400
|
||||
: 604800;
|
||||
|
||||
for (let i = 0; i < 7; i++) {
|
||||
points.push(timestamp2string1(lastTime - (i * interval), dataExportDefaultTime));
|
||||
}
|
||||
return points.reverse();
|
||||
};
|
||||
|
||||
timePoints = generateTimePoints();
|
||||
|
||||
timePoints = Array.from({length: 7}, (_, i) =>
|
||||
timestamp2string1(lastTime - (6-i) * interval, dataExportDefaultTime)
|
||||
);
|
||||
}
|
||||
|
||||
// 为每个时间点和模型生成数据
|
||||
// 生成柱状图数据
|
||||
timePoints.forEach(time => {
|
||||
Array.from(uniqueModels).forEach(model => {
|
||||
let existingData = data.find(item =>
|
||||
timestamp2string1(item.created_at, dataExportDefaultTime) === time &&
|
||||
item.model_name === model
|
||||
);
|
||||
|
||||
newLineData.push({
|
||||
// 为每个时间点收集所有模型的数据
|
||||
let timeData = Array.from(uniqueModels).map(model => {
|
||||
const key = `${time}-${model}`;
|
||||
const aggregated = aggregatedData.get(key);
|
||||
return {
|
||||
Time: time,
|
||||
Model: model,
|
||||
Usage: existingData ? parseFloat(getQuotaWithUnit(existingData.quota)) : 0
|
||||
});
|
||||
rawQuota: aggregated?.quota || 0,
|
||||
Usage: aggregated?.quota ? getQuotaWithUnit(aggregated.quota, 4) : 0
|
||||
};
|
||||
});
|
||||
|
||||
// 计算该时间点的总计
|
||||
const timeSum = timeData.reduce((sum, item) => sum + item.rawQuota, 0);
|
||||
|
||||
// 按照 rawQuota 从大到小排序
|
||||
timeData.sort((a, b) => b.rawQuota - a.rawQuota);
|
||||
|
||||
// 为每个数据点添加该时间的总计
|
||||
timeData = timeData.map(item => ({
|
||||
...item,
|
||||
TimeSum: timeSum
|
||||
}));
|
||||
|
||||
// 将排序后的数据添加到 newLineData
|
||||
newLineData.push(...timeData);
|
||||
});
|
||||
|
||||
// 排序
|
||||
|
||||
@@ -2,11 +2,12 @@ import React, { useCallback, useContext, useEffect, useState } from 'react';
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../../context/User/index.js';
|
||||
import { API, getUserIdFromLocalStorage, showError } from '../../helpers/index.js';
|
||||
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography, Button } from '@douyinfe/semi-ui';
|
||||
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography, Button, Highlight } from '@douyinfe/semi-ui';
|
||||
import { SSE } from 'sse';
|
||||
import { IconSetting } from '@douyinfe/semi-icons';
|
||||
import { StyleContext } from '../../context/Style/index.js';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { renderGroupOption } from '../../helpers/render.js';
|
||||
|
||||
const roleInfo = {
|
||||
user: {
|
||||
@@ -97,15 +98,17 @@ const Playground = () => {
|
||||
let res = await API.get(`/api/user/self/groups`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let localGroupOptions = Object.keys(data).map((group) => ({
|
||||
label: data[group],
|
||||
let localGroupOptions = Object.entries(data).map(([group, info]) => ({
|
||||
label: info.desc,
|
||||
value: group,
|
||||
ratio: info.ratio
|
||||
}));
|
||||
|
||||
if (localGroupOptions.length === 0) {
|
||||
localGroupOptions = [{
|
||||
label: t('用户分组'),
|
||||
value: '',
|
||||
ratio: 1
|
||||
}];
|
||||
} else {
|
||||
const localUser = JSON.parse(localStorage.getItem('user'));
|
||||
@@ -326,12 +329,9 @@ const Playground = () => {
|
||||
}}
|
||||
value={inputs.group}
|
||||
autoComplete='new-password'
|
||||
optionList={groups.map((group) => ({
|
||||
...group,
|
||||
label: styleState.isMobile && group.label.length > 16
|
||||
? group.label.substring(0, 16) + '...'
|
||||
: group.label,
|
||||
}))}
|
||||
optionList={groups}
|
||||
renderOptionItem={renderGroupOption}
|
||||
style={{ width: '100%' }}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>{t('模型')}:</Typography.Text>
|
||||
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
showSuccess,
|
||||
timestamp2string,
|
||||
} from '../../helpers';
|
||||
import { renderQuotaWithPrompt } from '../../helpers/render';
|
||||
import { renderGroupOption, renderQuotaWithPrompt } from '../../helpers/render';
|
||||
import {
|
||||
AutoComplete,
|
||||
Banner,
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
} from '@douyinfe/semi-ui';
|
||||
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
|
||||
import { Divider } from 'semantic-ui-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const EditToken = (props) => {
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
@@ -52,6 +53,7 @@ const EditToken = (props) => {
|
||||
const [models, setModels] = useState([]);
|
||||
const [groups, setGroups] = useState([]);
|
||||
const navigate = useNavigate();
|
||||
const { t } = useTranslation();
|
||||
const handleInputChange = (name, value) => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
@@ -87,7 +89,7 @@ const EditToken = (props) => {
|
||||
}));
|
||||
setModels(localModelOptions);
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -95,15 +97,14 @@ const EditToken = (props) => {
|
||||
let res = await API.get(`/api/user/self/groups`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
// return data is a map, key is group name, value is group description
|
||||
// label is group description, value is group name
|
||||
let localGroupOptions = Object.keys(data).map((group) => ({
|
||||
label: data[group],
|
||||
value: group,
|
||||
}));
|
||||
setGroups(localGroupOptions);
|
||||
let localGroupOptions = Object.entries(data).map(([group, info]) => ({
|
||||
label: info.desc,
|
||||
value: group,
|
||||
ratio: info.ratio
|
||||
}));
|
||||
setGroups(localGroupOptions);
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -176,7 +177,7 @@ const EditToken = (props) => {
|
||||
if (localInputs.expired_time !== -1) {
|
||||
let time = Date.parse(localInputs.expired_time);
|
||||
if (isNaN(time)) {
|
||||
showError('过期时间格式错误!');
|
||||
showError(t('过期时间格式错误!'));
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
@@ -189,11 +190,11 @@ const EditToken = (props) => {
|
||||
});
|
||||
const { success, message } = res.data;
|
||||
if (success) {
|
||||
showSuccess('令牌更新成功!');
|
||||
showSuccess(t('令牌更新成功!'));
|
||||
props.refresh();
|
||||
props.handleClose();
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
}
|
||||
} else {
|
||||
// 处理新增多个令牌的情况
|
||||
@@ -209,7 +210,7 @@ const EditToken = (props) => {
|
||||
if (localInputs.expired_time !== -1) {
|
||||
let time = Date.parse(localInputs.expired_time);
|
||||
if (isNaN(time)) {
|
||||
showError('过期时间格式错误!');
|
||||
showError(t('过期时间格式错误!'));
|
||||
setLoading(false);
|
||||
break;
|
||||
}
|
||||
@@ -222,14 +223,14 @@ const EditToken = (props) => {
|
||||
if (success) {
|
||||
successCount++;
|
||||
} else {
|
||||
showError(message);
|
||||
showError(t(message));
|
||||
break; // 如果创建失败,终止循环
|
||||
}
|
||||
}
|
||||
|
||||
if (successCount > 0) {
|
||||
showSuccess(
|
||||
`${successCount}个令牌创建成功,请在列表页面点击复制获取令牌!`,
|
||||
t('令牌创建成功,请在列表页面点击复制获取令牌!')
|
||||
);
|
||||
props.refresh();
|
||||
props.handleClose();
|
||||
@@ -245,7 +246,7 @@ const EditToken = (props) => {
|
||||
<SideSheet
|
||||
placement={isEdit ? 'right' : 'left'}
|
||||
title={
|
||||
<Title level={3}>{isEdit ? '更新令牌信息' : '创建新的令牌'}</Title>
|
||||
<Title level={3}>{isEdit ? t('更新令牌信息') : t('创建新的令牌')}</Title>
|
||||
}
|
||||
headerStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
|
||||
bodyStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
|
||||
@@ -254,7 +255,7 @@ const EditToken = (props) => {
|
||||
<div style={{ display: 'flex', justifyContent: 'flex-end' }}>
|
||||
<Space>
|
||||
<Button theme='solid' size={'large'} onClick={submit}>
|
||||
提交
|
||||
{t('提交')}
|
||||
</Button>
|
||||
<Button
|
||||
theme='solid'
|
||||
@@ -262,7 +263,7 @@ const EditToken = (props) => {
|
||||
type={'tertiary'}
|
||||
onClick={handleCancel}
|
||||
>
|
||||
取消
|
||||
{t('取消')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
@@ -274,9 +275,9 @@ const EditToken = (props) => {
|
||||
<Spin spinning={loading}>
|
||||
<Input
|
||||
style={{ marginTop: 20 }}
|
||||
label='名称'
|
||||
label={t('名称')}
|
||||
name='name'
|
||||
placeholder={'请输入名称'}
|
||||
placeholder={t('请输入名称')}
|
||||
onChange={(value) => handleInputChange('name', value)}
|
||||
value={name}
|
||||
autoComplete='new-password'
|
||||
@@ -284,9 +285,9 @@ const EditToken = (props) => {
|
||||
/>
|
||||
<Divider />
|
||||
<DatePicker
|
||||
label='过期时间'
|
||||
label={t('过期时间')}
|
||||
name='expired_time'
|
||||
placeholder={'请选择过期时间'}
|
||||
placeholder={t('请选择过期时间')}
|
||||
onChange={(value) => handleInputChange('expired_time', value)}
|
||||
value={expired_time}
|
||||
autoComplete='new-password'
|
||||
@@ -300,7 +301,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(0, 0, 0, 0);
|
||||
}}
|
||||
>
|
||||
永不过期
|
||||
{t('永不过期')}
|
||||
</Button>
|
||||
<Button
|
||||
type={'tertiary'}
|
||||
@@ -308,7 +309,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(0, 0, 1, 0);
|
||||
}}
|
||||
>
|
||||
一小时
|
||||
{t('一小时')}
|
||||
</Button>
|
||||
<Button
|
||||
type={'tertiary'}
|
||||
@@ -316,7 +317,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(1, 0, 0, 0);
|
||||
}}
|
||||
>
|
||||
一个月
|
||||
{t('一个月')}
|
||||
</Button>
|
||||
<Button
|
||||
type={'tertiary'}
|
||||
@@ -324,7 +325,7 @@ const EditToken = (props) => {
|
||||
setExpiredTime(0, 1, 0, 0);
|
||||
}}
|
||||
>
|
||||
一天
|
||||
{t('一天')}
|
||||
</Button>
|
||||
</Space>
|
||||
</div>
|
||||
@@ -332,17 +333,15 @@ const EditToken = (props) => {
|
||||
<Divider />
|
||||
<Banner
|
||||
type={'warning'}
|
||||
description={
|
||||
'注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。'
|
||||
}
|
||||
description={t('注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。')}
|
||||
></Banner>
|
||||
<div style={{ marginTop: 20 }}>
|
||||
<Typography.Text>{`额度${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
|
||||
<Typography.Text>{`${t('额度')}${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
|
||||
</div>
|
||||
<AutoComplete
|
||||
style={{ marginTop: 8 }}
|
||||
name='remain_quota'
|
||||
placeholder={'请输入额度'}
|
||||
placeholder={t('请输入额度')}
|
||||
onChange={(value) => handleInputChange('remain_quota', value)}
|
||||
value={remain_quota}
|
||||
autoComplete='new-password'
|
||||
@@ -362,22 +361,22 @@ const EditToken = (props) => {
|
||||
{!isEdit && (
|
||||
<>
|
||||
<div style={{ marginTop: 20 }}>
|
||||
<Typography.Text>新建数量</Typography.Text>
|
||||
<Typography.Text>{t('新建数量')}</Typography.Text>
|
||||
</div>
|
||||
<AutoComplete
|
||||
style={{ marginTop: 8 }}
|
||||
label='数量'
|
||||
placeholder={'请选择或输入创建令牌的数量'}
|
||||
label={t('数量')}
|
||||
placeholder={t('请选择或输入创建令牌的数量')}
|
||||
onChange={(value) => handleTokenCountChange(value)}
|
||||
onSelect={(value) => handleTokenCountChange(value)}
|
||||
value={tokenCount.toString()}
|
||||
autoComplete='off'
|
||||
type='number'
|
||||
data={[
|
||||
{ value: 10, label: '10个' },
|
||||
{ value: 20, label: '20个' },
|
||||
{ value: 30, label: '30个' },
|
||||
{ value: 100, label: '100个' },
|
||||
{ value: 10, label: t('10个') },
|
||||
{ value: 20, label: t('20个') },
|
||||
{ value: 30, label: t('30个') },
|
||||
{ value: 100, label: t('100个') },
|
||||
]}
|
||||
disabled={unlimited_quota}
|
||||
/>
|
||||
@@ -392,17 +391,17 @@ const EditToken = (props) => {
|
||||
setUnlimitedQuota();
|
||||
}}
|
||||
>
|
||||
{unlimited_quota ? '取消无限额度' : '设为无限额度'}
|
||||
{unlimited_quota ? t('取消无限额度') : t('设为无限额度')}
|
||||
</Button>
|
||||
</div>
|
||||
<Divider />
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text>IP白名单(请勿过度信任此功能)</Typography.Text>
|
||||
<Typography.Text>{t('IP白名单(请勿过度信任此功能)')}</Typography.Text>
|
||||
</div>
|
||||
<TextArea
|
||||
label='IP白名单'
|
||||
label={t('IP白名单')}
|
||||
name='allow_ips'
|
||||
placeholder={'允许的IP,一行一个'}
|
||||
placeholder={t('允许的IP,一行一个')}
|
||||
onChange={(value) => {
|
||||
handleInputChange('allow_ips', value);
|
||||
}}
|
||||
@@ -417,16 +416,15 @@ const EditToken = (props) => {
|
||||
onChange={(e) =>
|
||||
handleInputChange('model_limits_enabled', e.target.checked)
|
||||
}
|
||||
></Checkbox>
|
||||
<Typography.Text>
|
||||
启用模型限制(非必要,不建议启用)
|
||||
</Typography.Text>
|
||||
>
|
||||
{t('启用模型限制(非必要,不建议启用)')}
|
||||
</Checkbox>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
<Select
|
||||
style={{ marginTop: 8 }}
|
||||
placeholder={'请选择该渠道所支持的模型'}
|
||||
placeholder={t('请选择该渠道所支持的模型')}
|
||||
name='models'
|
||||
required
|
||||
multiple
|
||||
@@ -440,25 +438,27 @@ const EditToken = (props) => {
|
||||
disabled={!model_limits_enabled}
|
||||
/>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text>令牌分组,默认为用户的分组</Typography.Text>
|
||||
<Typography.Text>{t('令牌分组,默认为用户的分组')}</Typography.Text>
|
||||
</div>
|
||||
{groups.length > 0 ?
|
||||
<Select
|
||||
style={{ marginTop: 8 }}
|
||||
placeholder={'令牌分组,默认为用户的分组'}
|
||||
placeholder={t('令牌分组,默认为用户的分组')}
|
||||
name='gruop'
|
||||
required
|
||||
selection
|
||||
onChange={(value) => {
|
||||
handleInputChange('group', value);
|
||||
}}
|
||||
position={'topLeft'}
|
||||
renderOptionItem={renderGroupOption}
|
||||
value={inputs.group}
|
||||
autoComplete='new-password'
|
||||
optionList={groups}
|
||||
/>:
|
||||
<Select
|
||||
style={{ marginTop: 8 }}
|
||||
placeholder={'管理员未设置用户可选分组'}
|
||||
placeholder={t('管理员未设置用户可选分组')}
|
||||
name='gruop'
|
||||
disabled={true}
|
||||
/>
|
||||
|
||||
Reference in New Issue
Block a user