Compare commits

..

68 Commits

Author SHA1 Message Date
1808837298@qq.com
ef4c1a2e48 fix: retry prompt tokens 2025-01-02 16:33:00 +08:00
Calcium-Ion
ba1aad8ac4 Merge pull request #686 from delph1s/main
fix: try to fix pgsql #685
2025-01-02 00:17:02 +08:00
delph1s
42bf95bd54 fix: try to fix pgsql #685 2025-01-02 00:14:16 +08:00
Calcium-Ion
bf9a492f25 Update README.md 2024-12-31 22:19:37 +08:00
Calcium-Ion
16725d1226 Merge pull request #683 from iszcz/new512
Update channel-test.go
2024-12-31 20:22:57 +08:00
CalciumIon
e6ea5e59c0 fix: error page size opts 2024-12-31 15:51:15 +08:00
CalciumIon
4f196a62e1 feat: implement pagination and total count for redemptions API #386
- Updated GetAllRedemptions and SearchRedemptions functions to return total count along with paginated results.
- Modified API endpoints to accept page size as a parameter, enhancing flexibility in data retrieval.
- Adjusted RedemptionsTable component to support pagination and display total count, improving user experience.
- Ensured consistent handling of pagination across related components, including LogsTable and UsersTable.
2024-12-31 15:28:25 +08:00
CalciumIon
014fb7edab feat: enhance user search functionality with pagination support
- Updated SearchUsers function to include pagination parameters (startIdx and num) for improved user search results.
- Modified API response structure to return paginated data, including total user count and current page information.
- Adjusted UsersTable component to handle pagination and search parameters, ensuring a seamless user experience.
- Added internationalization support for new search functionality in the UI.
2024-12-31 15:02:59 +08:00
CalciumIon
be0b2f6a64 feat: enhance user management and pagination features #518
- Updated GetAllUsers function to return total user count along with paginated results, improving data handling in user retrieval.
- Modified GetAllUsers API endpoint to accept page size as a parameter, allowing for dynamic pagination.
- Enhanced UsersTable component to support customizable page sizes and improved pagination logic.
- Added error handling for empty username and password in AddUser component.
- Updated LogsTable component to display pagination information in a user-friendly format.
2024-12-31 14:52:55 +08:00
iszcz
687f07bc10 Update channel-test.go 2024-12-31 12:49:13 +08:00
CalciumIon
a7e5f1e509 fix: try to fix pgsql #682 2024-12-31 02:10:19 +08:00
CalciumIon
87d5e286d5 fix: try to fix pgsql #682 2024-12-31 02:06:30 +08:00
CalciumIon
b4f17543cb fix redis 2024-12-30 22:05:41 +08:00
CalciumIon
1eb706de7a docs: update README 2024-12-30 20:56:54 +08:00
CalciumIon
d13d81baba refactor: update group handling and rendering logic
- Changed the structure of usableGroups in GetUserGroups to store additional information (ratio and description) for each group.
- Introduced a new renderRatio function to visually represent group ratios with color coding.
- Updated the Playground and EditToken components to utilize the new group structure and rendering options.
- Enhanced the renderGroupOption function for better UI representation of group options.
- Fixed minor comments and improved code readability.
2024-12-30 19:51:00 +08:00
Calcium-Ion
65af1a4d10 Merge pull request #679 from kingxjs/main
fix: use document to build input fix copy command
2024-12-30 18:02:21 +08:00
Calcium-Ion
1ae0a3fb83 Merge pull request #677 from mageia/master
修复 PostgreSQL 中用户组查询错误
2024-12-30 18:01:51 +08:00
CalciumIon
fe2e8f1a42 Merge branch 'main'
# Conflicts:
#	model/user.go
2024-12-30 18:00:59 +08:00
Calcium-Ion
a5f7f8af29 Merge pull request #680 from Calcium-Ion/refactor_redis
Refactor redis
2024-12-30 17:55:07 +08:00
CalciumIon
2f01a2125f feat: enhance environment variable handling and security features 2024-12-30 17:24:19 +08:00
迷糊虫
e4f9787c16 使用原生document构建input再次尝试复制命令 2024-12-30 17:13:49 +08:00
CalciumIon
bb5e032dd2 refactor: token cache logic 2024-12-30 17:10:48 +08:00
Mageia
304c92ceab 修复 PostgreSQL 中用户组查询错误
- 修复 model/user.go 中的 SQL 查询,使用双引号将 group 列名括起来
- 对于 PostgreSQL 数据库,`group` 是保留关键字,需要用双引号括起来避免语法错误。该修改确保了代码在 PostgreSQL 和其他数据库(如 MySQL)中都能正常工作。
2024-12-30 10:23:55 +08:00
Calcium-Ion
05874dcca5 Merge pull request #676 from Calcium-Ion/refactor_redis
refactor: user cache logic
2024-12-29 17:55:52 +08:00
CalciumIon
ca8b7ed1c3 refactor: remove redundant group column handling in user queries 2024-12-29 17:02:30 +08:00
CalciumIon
ed435e5c8f refactor: user cache logic 2024-12-29 16:50:26 +08:00
Calcium-Ion
a1b864bc5e Merge pull request #674 from Yan-Zero/main
fix: Gemini 函数调用的文本转义,以及其他文件类型的 Base64 支持
2024-12-29 13:11:02 +08:00
Yan
2a15dfccea fix: Gemini 其他文件类型的支持(Base64URL) 2024-12-29 10:11:39 +08:00
Yan
9e5a7ed541 fix: Gemini 函数调用的文本转义 2024-12-29 06:11:44 +08:00
CalciumIon
65d1cde8fb fix: playground request_start_time 2024-12-29 01:03:02 +08:00
CalciumIon
8f4a2df5ee fix: prevent setting models to null in PersonalSetting component 2024-12-29 00:24:02 +08:00
CalciumIon
2b38e8ed8d feat: add multi-file type support for Gemini and Claude
- Add file data DTO for structured file handling
- Implement file decoder service
- Update Claude and Gemini relay channels to handle various file types
- Reorganize worker service to cf_worker for clarity
- Update token counter and image service for new file types
2024-12-29 00:00:24 +08:00
CalciumIon
d75ecfc63e chore: update language in index.html to Chinese 2024-12-28 20:43:26 +08:00
Calcium-Ion
91b777f33f Merge pull request #673 from Yan-Zero/main
fix: 转义 Gemini 工具调用中的反斜杠
2024-12-28 19:49:31 +08:00
Yan
72dc54309c fix: 转义 Gemini 工具调用中的反斜杠 2024-12-28 18:29:48 +08:00
Calcium-Ion
458dd1bd9d Merge pull request #672 from Yan-Zero/main
fix: add index in the tool calls when chat by stream (gemini)
2024-12-28 18:20:33 +08:00
Yan
38cff317a0 fix: add index in the tool calls when chat by stream (gemini) 2024-12-28 17:56:31 +08:00
CalciumIon
c8614f9890 refactor: Playground controller 2024-12-28 16:47:56 +08:00
CalciumIon
10d896aa7f refactor: streamline log processing by introducing formatUserLogs function 2024-12-28 16:40:29 +08:00
CalciumIon
118eb362c4 refactor: enhance log retrieval and user interaction in LogsTable component 2024-12-28 15:34:28 +08:00
CalciumIon
52c023a1dd fix #663 2024-12-27 21:59:05 +08:00
CalciumIon
1cef91a741 fix: prevent duplicate models in user group retrieval 2024-12-27 21:25:44 +08:00
CalciumIon
77861e6440 refactor: improve user group handling and add GetUserUsableGroups function
- Introduced a new function `GetUserUsableGroupsCopy` to return a copy of user usable groups.
- Updated `GetUserUsableGroups` to utilize the new function for better encapsulation.
- Changed variable names from `UserUsableGroups` to `userUsableGroups` for consistency.
- Enhanced `GetUserUsableGroups` logic to ensure it returns a copy of the groups, preventing unintended modifications.
2024-12-27 21:19:22 +08:00
CalciumIon
5f082d72bb update dockerignore 2024-12-27 20:49:58 +08:00
CalciumIon
0fd0e5d309 fix: oauth bind 2024-12-27 18:32:11 +08:00
CalciumIon
d2297d2723 feat: update o1 default token encoder 2024-12-27 15:03:10 +08:00
CalciumIon
62ae46b552 feat: support azure stream_options 2024-12-26 22:51:06 +08:00
CalciumIon
0b1354ed51 update model ratio 2024-12-26 16:03:22 +08:00
Calcium-Ion
132c71390c Merge pull request #661 from tenacioustommy/fix-title-schema
fix delete title schema
2024-12-26 14:27:07 +08:00
Calcium-Ion
bb3deb7b93 Merge pull request #662 from xqx333/main
fix 重试过程多次获取图片
2024-12-26 14:26:50 +08:00
CalciumIon
f92d96e298 fix: update render function for quota display in Detail page 2024-12-26 14:25:44 +08:00
xqx333
c86762b656 Update relay-text.go
在上下文中存入promptTokens,避免重试过程重复计算
2024-12-26 02:00:04 +08:00
tenacious
3409d7a6b6 fix delete title schema 2024-12-26 00:24:45 +08:00
CalciumIon
bfba4866a5 fix: validate number input in renderQuotaNumberWithDigit and improve data handling in Detail page
- Added input validation to ensure that the `num` parameter in `renderQuotaNumberWithDigit` is a valid number, returning 0 for invalid inputs.
- Updated the `Detail` component to use `datum['rawQuota']` instead of `datum['Usage']` for rendering quota values, ensuring more accurate data representation.
- Enhanced data aggregation logic to handle cases where quota values may be missing or invalid, improving overall data integrity in charts and tables.
- Removed unnecessary time granularity calculations and streamlined the data processing for better performance.
2024-12-25 23:16:35 +08:00
CalciumIon
4fc1fe318e refactor: migrate group ratio and user usable groups logic to new setting package
- Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services.
- Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability.
- Updated related functions to ensure consistent behavior with the new setting package integration.
2024-12-25 19:31:12 +08:00
CalciumIon
b3576f24ef fix typo 2024-12-25 18:44:45 +08:00
CalciumIon
ed4d26fc9e fix: update MaxCompletionTokens for model prefix handling in buildTestRequest function 2024-12-25 17:55:20 +08:00
CalciumIon
ba56e2e8ca fix: correct user retrieval in GetPricing function 2024-12-25 14:29:52 +08:00
CalciumIon
7c20e6d047 fix: resolve pricing calculation issue (#659) 2024-12-25 14:26:43 +08:00
CalciumIon
72d6898eb5 feat: Implement batch tagging functionality for channels
- Added a new endpoint to batch set tags for multiple channels, allowing users to update tags efficiently.
- Introduced a new `BatchSetChannelTag` function in the controller to handle incoming requests and validate parameters.
- Updated the `BatchSetChannelTag` method in the model to manage database transactions and ensure data integrity during tag updates.
- Enhanced the ChannelsTable component in the frontend to support batch tag setting, including UI elements for user interaction.
- Updated localization files to include new translation keys related to batch operations and tag settings.
2024-12-25 14:19:00 +08:00
CalciumIon
f2c9388139 fix: update searchUsers function to include searchKeyword and searchGroup parameters 2024-12-25 13:44:55 +08:00
Calcium-Ion
aaf5cecefd Merge pull request #656 from Yan-Zero/main
fix: gemini function call
2024-12-25 13:38:34 +08:00
CalciumIon
fe2165ace6 fix: #657 2024-12-24 22:30:05 +08:00
CalciumIon
3003d12a20 fix: get upstream models 2024-12-24 20:48:21 +08:00
Yan
a8a2195ab1 Merge branch 'Calcium-Ion:main' into main 2024-12-24 20:46:16 +08:00
Yan
d40e6ec25d fix: gemini func call 2024-12-24 20:46:02 +08:00
CalciumIon
8129aa76f9 feat: Enhance pricing functionality with user group support
- Updated the GetPricing function in the backend to include user group information, allowing for dynamic adjustment of group ratios based on the user's group.
- Implemented logic to filter group ratios based on the user's usable groups, improving the accuracy of pricing data returned.
- Modified the ModelPricing component to utilize the new usable group data, ensuring only relevant groups are displayed in the UI.
- Enhanced state management in the frontend to accommodate the new usable group information, improving user experience and data consistency.
2024-12-24 19:23:29 +08:00
CalciumIon
fb8595da18 feat: Update localization and enhance token editing functionality
- Added new translation keys for English localization in `en.json`, including "Token group, default is the your's group" and "IP whitelist (do not overly trust this function)".
- Refactored `EditToken.js` to utilize the `useTranslation` hook for improved internationalization, ensuring all user-facing strings are translatable.
- Updated error and success messages to use translation functions, enhancing user experience for non-English speakers.
- Improved UI elements to support localization, including labels, placeholders, and button texts, ensuring consistency across the token editing interface.
2024-12-24 18:40:18 +08:00
81 changed files with 4854 additions and 3931 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.github
.git
*.md
.vscode
.gitignore
Makefile

3
.gitignore vendored
View File

@@ -8,4 +8,5 @@ build
logs
web/dist
.env
one-api
one-api
.DS_Store

View File

@@ -82,6 +82,8 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
- `GEMINI_MODEL_MAP`: Specify Gemini model versions (v1/v1beta), format: "model:version", comma-separated
- `COHERE_SAFETY_SETTING`: Cohere model [safety settings](https://docs.cohere.com/docs/safety-modes#overview), options: `NONE`, `CONTEXTUAL`, `STRICT`, default `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`: Gemini model maximum image number, default `16`, set to `-1` to disable
- `MAX_FILE_DOWNLOAD_MB`: Maximum file download size in MB, default `20`
- `CRYPTO_SECRET`: Encryption key for encrypting database content
## Deployment
> [!TIP]
@@ -92,6 +94,10 @@ You can add custom models gpt-4-gizmo-* in channels. These are third-party model
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
### Multi-Server Deployment
- Must set `SESSION_SECRET` environment variable, otherwise login state will not be consistent across multiple servers.
- If using a public Redis, must set `CRYPTO_SECRET` environment variable, otherwise Redis content will not be able to be obtained in multi-server deployment.
### Requirements
- Local database (default): SQLite (Docker deployment must mount `/data` directory)
- Remote database: MySQL >= 5.7.8, PgSQL >= 9.6

View File

@@ -88,6 +88,8 @@
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
- `COHERE_SAFETY_SETTING`Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL``STRICT`,默认为 `NONE`
- `GEMINI_VISION_MAX_IMAGE_NUM`Gemini模型最大图片数量默认为 `16`,设置为 `-1` 则不限制。
- `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB默认为 `20`
- `CRYPTO_SECRET`:加密密钥,用于加密数据库内容。
## 部署
> [!TIP]
> 最新版Docker镜像`calciumion/new-api:latest`
@@ -97,6 +99,10 @@
> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR
> ```
### 多机部署
- 必须设置环境变量 `SESSION_SECRET`,否则会导致多机部署时登录状态不一致。
- 如果公用Redis必须设置 `CRYPTO_SECRET`否则会导致多机部署时Redis内容无法获取。
### 部署要求
- 本地数据库默认SQLiteDocker 部署默认使用 SQLite必须挂载 `/data` 目录到宿主机)
- 远程数据库MySQL 版本 >= 5.7.8PgSQL 版本 >= 9.6
@@ -153,15 +159,14 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
[对接文档](Suno.md)
## 界面截图
![796df8d287b7b7bd7853b2497e7df511](https://github.com/user-attachments/assets/255b5e97-2d3a-4434-b4fa-e922ad88ff5a)
![image](https://github.com/user-attachments/assets/a0dcd349-5df8-4dc8-9acf-ca272b239919)
![image](https://github.com/user-attachments/assets/c7d0f7e1-729c-43e2-ac7c-2cb73b0afc8e)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/user-attachments/assets/29f81de5-33fc-4fc5-a5ff-f9b54b653c7c)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
![image](https://github.com/user-attachments/assets/4fa53e18-d2c5-477a-9b26-b86e44c71e35)
## 交流群
<img src="https://github.com/user-attachments/assets/9ca0bc82-e057-4230-a28d-9f198fa022e3" width="200">

View File

@@ -30,6 +30,7 @@ var DefaultCollapseSidebar = false // default value of collapse sidebar
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var CryptoSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex

View File

@@ -1,6 +1,23 @@
package common
import "golang.org/x/crypto/bcrypt"
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"golang.org/x/crypto/bcrypt"
)
func GenerateHMACWithKey(key []byte, data string) string {
h := hmac.New(sha256.New, key)
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
func GenerateHMAC(data string) string {
h := hmac.New(sha256.New, []byte(CryptoSecret))
h.Write([]byte(data))
return hex.EncodeToString(h.Sum(nil))
}
func Password2Hash(password string) (string, error) {
passwordBytes := []byte(password)

View File

@@ -22,7 +22,7 @@ func printHelp() {
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
}
func init() {
func LoadEnv() {
flag.Parse()
if *PrintVersion {
@@ -45,6 +45,11 @@ func init() {
SessionSecret = ss
}
}
if os.Getenv("CRYPTO_SECRET") != "" {
CryptoSecret = os.Getenv("CRYPTO_SECRET")
} else {
CryptoSecret = SessionSecret
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")
}

View File

@@ -356,7 +356,7 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
if strings.HasPrefix(name, "o1-") {
if strings.HasPrefix(name, "o1") {
return 4
}
if name == "chatgpt-4o-latest" {

View File

@@ -2,9 +2,15 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
"errors"
"fmt"
"os"
"reflect"
"strconv"
"time"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
var RDB *redis.Client
@@ -56,39 +62,167 @@ func RedisGet(key string) (string, error) {
return RDB.Get(ctx, key).Result()
}
func RedisExpire(key string, expiration time.Duration) error {
ctx := context.Background()
return RDB.Expire(ctx, key, expiration).Err()
}
func RedisGetEx(key string, expiration time.Duration) (string, error) {
ctx := context.Background()
return RDB.GetSet(ctx, key, expiration).Result()
}
//func RedisExpire(key string, expiration time.Duration) error {
// ctx := context.Background()
// return RDB.Expire(ctx, key, expiration).Err()
//}
//
//func RedisGetEx(key string, expiration time.Duration) (string, error) {
// ctx := context.Background()
// return RDB.GetSet(ctx, key, expiration).Result()
//}
func RedisDel(key string) error {
ctx := context.Background()
return RDB.Del(ctx, key).Err()
}
func RedisDecrease(key string, value int64) error {
func RedisHDelObj(key string) error {
ctx := context.Background()
return RDB.HDel(ctx, key).Err()
}
func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
ctx := context.Background()
data := make(map[string]interface{})
// 使用反射遍历结构体字段
v := reflect.ValueOf(obj).Elem()
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
value := v.Field(i)
// Skip DeletedAt field
if field.Type.String() == "gorm.DeletedAt" {
continue
}
// 处理指针类型
if value.Kind() == reflect.Ptr {
if value.IsNil() {
data[field.Name] = ""
continue
}
value = value.Elem()
}
// 处理布尔类型
if value.Kind() == reflect.Bool {
data[field.Name] = strconv.FormatBool(value.Bool())
continue
}
// 其他类型直接转换为字符串
data[field.Name] = fmt.Sprintf("%v", value.Interface())
}
txn := RDB.TxPipeline()
txn.HSet(ctx, key, data)
txn.Expire(ctx, key, expiration)
_, err := txn.Exec(ctx)
if err != nil {
return fmt.Errorf("failed to execute transaction: %w", err)
}
return nil
}
func RedisHGetObj(key string, obj interface{}) error {
ctx := context.Background()
result, err := RDB.HGetAll(ctx, key).Result()
if err != nil {
return fmt.Errorf("failed to load hash from Redis: %w", err)
}
if len(result) == 0 {
return fmt.Errorf("key %s not found in Redis", key)
}
// Handle both pointer and non-pointer values
val := reflect.ValueOf(obj)
if val.Kind() != reflect.Ptr {
return fmt.Errorf("obj must be a pointer to a struct, got %T", obj)
}
v := val.Elem()
if v.Kind() != reflect.Struct {
return fmt.Errorf("obj must be a pointer to a struct, got pointer to %T", v.Interface())
}
t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
fieldName := field.Name
if value, ok := result[fieldName]; ok {
fieldValue := v.Field(i)
// Handle pointer types
if fieldValue.Kind() == reflect.Ptr {
if value == "" {
continue
}
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
}
fieldValue = fieldValue.Elem()
}
// Enhanced type handling for Token struct
switch fieldValue.Kind() {
case reflect.String:
fieldValue.SetString(value)
case reflect.Int, reflect.Int64:
intValue, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse int field %s: %w", fieldName, err)
}
fieldValue.SetInt(intValue)
case reflect.Bool:
boolValue, err := strconv.ParseBool(value)
if err != nil {
return fmt.Errorf("failed to parse bool field %s: %w", fieldName, err)
}
fieldValue.SetBool(boolValue)
case reflect.Struct:
// Special handling for gorm.DeletedAt
if fieldValue.Type().String() == "gorm.DeletedAt" {
if value != "" {
timeValue, err := time.Parse(time.RFC3339, value)
if err != nil {
return fmt.Errorf("failed to parse DeletedAt field %s: %w", fieldName, err)
}
fieldValue.Set(reflect.ValueOf(gorm.DeletedAt{Time: timeValue, Valid: true}))
}
}
default:
return fmt.Errorf("unsupported field type: %s for field %s", fieldValue.Kind(), fieldName)
}
}
}
return nil
}
// RedisIncr Add this function to handle atomic increments
func RedisIncr(key string, delta int64) error {
// 检查键的剩余生存时间
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil {
// 失败则尝试直接减少
return RDB.DecrBy(context.Background(), key, value).Err()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to get TTL: %w", err)
}
// 如果剩余生存时间大于0则进行减少操作
// 只有在 key 存在且有 TTL 时才需要特殊处理
if ttl > 0 {
ctx := context.Background()
// 开始一个Redis事务
txn := RDB.TxPipeline()
// 减少余额
decrCmd := txn.DecrBy(ctx, key, value)
decrCmd := txn.IncrBy(ctx, key, delta)
if err := decrCmd.Err(); err != nil {
return err // 如果减少失败,则直接返回错误
}
@@ -99,8 +233,54 @@ func RedisDecrease(key string, value int64) error {
// 执行事务
_, err = txn.Exec(ctx)
return err
} else {
_ = RedisDel(key)
}
return nil
}
func RedisHIncrBy(key, field string, delta int64) error {
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to get TTL: %w", err)
}
if ttl > 0 {
ctx := context.Background()
txn := RDB.TxPipeline()
incrCmd := txn.HIncrBy(ctx, key, field, delta)
if err := incrCmd.Err(); err != nil {
return err
}
txn.Expire(ctx, key, ttl)
_, err = txn.Exec(ctx)
return err
}
return nil
}
func RedisHSetField(key, field string, value interface{}) error {
ttlCmd := RDB.TTL(context.Background(), key)
ttl, err := ttlCmd.Result()
if err != nil && !errors.Is(err, redis.Nil) {
return fmt.Errorf("failed to get TTL: %w", err)
}
if ttl > 0 {
ctx := context.Background()
txn := RDB.TxPipeline()
hsetCmd := txn.HSet(ctx, key, field, value)
if err := hsetCmd.Err(); err != nil {
return err
}
txn.Expire(ctx, key, ttl)
_, err = txn.Exec(ctx)
return err
}
return nil
}

View File

@@ -35,9 +35,7 @@ func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return map[string]interface{}{
"result": str,
}
return nil
}
return m
}

View File

@@ -1,46 +0,0 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
if userGroup == "" {
// 如果userGroup为空返回UserUsableGroups
return UserUsableGroups
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := UserUsableGroups[userGroup]; !ok {
appendUserUsableGroups := make(map[string]string)
for k, v := range UserUsableGroups {
appendUserUsableGroups[k] = v
}
appendUserUsableGroups[userGroup] = "用户分组"
return appendUserUsableGroups
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return UserUsableGroups
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := UserUsableGroups[groupName]
return ok
}

23
constant/cache_key.go Normal file
View File

@@ -0,0 +1,23 @@
package constant
import "one-api/common"
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
// Cache keys
const (
UserGroupKeyFmt = "user_group:%d"
UserQuotaKeyFmt = "user_quota:%d"
UserEnabledKeyFmt = "user_enabled:%d"
UserUsernameKeyFmt = "user_name:%d"
)
const (
TokenFiledRemainQuota = "RemainQuota"
TokenFieldGroup = "Group"
)

View File

@@ -10,6 +10,8 @@ import (
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
var MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)

View File

@@ -1,6 +1,9 @@
package constant
var (
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonLength = "length"
FinishReasonFunctionCall = "function_call"
FinishReasonContentFilter = "content_filter"
)

View File

@@ -21,7 +21,7 @@ func GetSubscription(c *gin.Context) {
usedQuota = token.UsedQuota
} else {
userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId)
remainQuota, err = model.GetUserQuota(userId, false)
usedQuota, err = model.GetUserUsedQuota(userId)
}
if expiredTime <= 0 {

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"io"
"math"
"net/http"
@@ -24,6 +23,8 @@ import (
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
)
@@ -32,6 +33,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
if channel.Type == common.ChannelTypeMidjourney {
return errors.New("midjourney channel test is not supported"), nil
}
if channel.Type == common.ChannelTypeMidjourneyPlus {
return errors.New("midjourney plus channel test is not supported!!!"), nil
}
if channel.Type == common.ChannelTypeSunoAPI {
return errors.New("suno channel test is not supported"), nil
}
@@ -152,8 +156,8 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later
Stream: false,
}
if strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
if strings.HasPrefix(model, "o1") {
testRequest.MaxCompletionTokens = 10
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
testRequest.MaxTokens = 2
} else {

View File

@@ -115,8 +115,8 @@ func FetchUpstreamModels(c *gin.Context) {
// return
//}
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", baseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
@@ -419,7 +419,8 @@ func EditTagChannels(c *gin.Context) {
}
type ChannelBatch struct {
Ids []int `json:"ids"`
Ids []int `json:"ids"`
Tag *string `json:"tag"`
}
func DeleteChannelBatch(c *gin.Context) {
@@ -570,3 +571,29 @@ func FetchModels(c *gin.Context) {
"data": models,
})
}
func BatchSetChannelTag(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}

View File

@@ -3,13 +3,13 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
for groupName, _ := range setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -20,15 +20,18 @@ func GetGroups(c *gin.Context) {
}
func GetUserGroups(c *gin.Context) {
usableGroups := make(map[string]string)
usableGroups := make(map[string]map[string]interface{})
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio {
userGroup, _ = model.GetUserGroup(userId, false)
for groupName, ratio := range setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName]
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if desc, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = map[string]interface{}{
"ratio": ratio,
"desc": desc,
}
}
}
c.JSON(http.StatusOK, gin.H{

View File

@@ -166,7 +166,7 @@ func ListModels(c *gin.Context) {
}
} else {
userId := c.GetInt("id")
userGroup, err := model.GetUserGroup(userId)
userGroup, err := model.GetUserGroup(userId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
"strings"
"github.com/gin-gonic/gin"
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
err = common.CheckGroupRatio(option.Value)
err = setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

69
controller/playground.go Normal file
View File

@@ -0,0 +1,69 @@
package controller
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/middleware"
"one-api/model"
"one-api/service"
"one-api/setting"
"time"
)
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
if group == "" {
group = userGroup
} else {
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
c.Set(constant.ContextKeyRequestStartTime, time.Now())
Relay(c)
}

View File

@@ -4,14 +4,38 @@ import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
"one-api/setting"
)
func GetPricing(c *gin.Context) {
pricing := model.GetPricing()
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := map[string]float64{}
for s, f := range setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
if exists {
user, err := model.GetUserById(userId.(int), false)
if err == nil {
group = user.Group
}
}
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
for group := range setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}
}
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": common.GroupRatio,
"success": true,
"data": pricing,
"group_ratio": groupRatio,
"usable_group": usableGroup,
})
}

View File

@@ -1,19 +1,24 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-gonic/gin"
)
func GetAllRedemptions(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
if pageSize < 1 {
pageSize = common.ItemsPerPage
}
redemptions, total, err := model.GetAllRedemptions((p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -24,14 +29,27 @@ func GetAllRedemptions(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": redemptions,
"data": gin.H{
"items": redemptions,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
func SearchRedemptions(c *gin.Context) {
keyword := c.Query("keyword")
redemptions, err := model.SearchRedemptions(keyword)
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 0 {
p = 0
}
if pageSize < 1 {
pageSize = common.ItemsPerPage
}
redemptions, total, err := model.SearchRedemptions(keyword, (p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -42,7 +60,12 @@ func SearchRedemptions(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": redemptions,
"data": gin.H{
"items": redemptions,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}

View File

@@ -48,58 +48,6 @@ func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErr
return err
}
func Playground(c *gin.Context) {
var openaiErr *dto.OpenAIErrorWithStatusCode
defer func() {
if openaiErr != nil {
c.JSON(openaiErr.StatusCode, gin.H{
"error": openaiErr.Error,
})
}
}()
useAccessToken := c.GetBool("use_access_token")
if useAccessToken {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("暂不支持使用 access token"), "access_token_not_supported", http.StatusBadRequest)
return
}
playgroundRequest := &dto.PlayGroundRequest{}
err := common.UnmarshalBodyReusable(c, playgroundRequest)
if err != nil {
openaiErr = service.OpenAIErrorWrapperLocal(err, "unmarshal_request_failed", http.StatusBadRequest)
return
}
if playgroundRequest.Model == "" {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("请选择模型"), "model_required", http.StatusBadRequest)
return
}
c.Set("original_model", playgroundRequest.Model)
group := playgroundRequest.Group
userGroup := c.GetString("group")
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}
c.Set("group", group)
}
c.Set("token_name", "playground-"+group)
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
return
}
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
Relay(c)
}
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)

View File

@@ -153,7 +153,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
//err = model.CacheUpdateUserQuota(task.UserId) ?
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {

View File

@@ -75,7 +75,7 @@ func RequestEpay(c *gin.Context) {
}
id := c.GetInt("id")
group, err := model.CacheGetUserGroup(id)
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return
@@ -236,7 +236,7 @@ func RequestAmount(c *gin.Context) {
return
}
id := c.GetInt("id")
group, err := model.CacheGetUserGroup(id)
group, err := model.GetUserGroup(id, true)
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
return

View File

@@ -6,13 +6,15 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
"strconv"
"strings"
"sync"
"one-api/constant"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"one-api/constant"
)
type LoginRequest struct {
@@ -241,10 +243,14 @@ func Register(c *gin.Context) {
func GetAllUsers(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
users, total, err := model.GetAllUsers((p-1)*pageSize, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -255,7 +261,12 @@ func GetAllUsers(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
"data": gin.H{
"items": users,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
@@ -263,7 +274,16 @@ func GetAllUsers(c *gin.Context) {
func SearchUsers(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")
users, err := model.SearchUsers(keyword, group)
p, _ := strconv.Atoi(c.Query("p"))
pageSize, _ := strconv.Atoi(c.Query("page_size"))
if p < 1 {
p = 1
}
if pageSize < 0 {
pageSize = common.ItemsPerPage
}
startIdx := (p - 1) * pageSize
users, total, err := model.SearchUsers(keyword, group, startIdx, pageSize)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -274,7 +294,12 @@ func SearchUsers(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
"data": gin.H{
"items": users,
"total": total,
"page": p,
"page_size": pageSize,
},
})
return
}
@@ -454,7 +479,15 @@ func GetUserModels(c *gin.Context) {
})
return
}
models := model.GetGroupModels(user.Group)
groups := setting.GetUserUsableGroups(user.Group)
var models []string
for group := range groups {
for _, g := range model.GetGroupModels(group) {
if !common.StringsContains(models, g) {
models = append(models, g)
}
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

8
dto/file_data.go Normal file
View File

@@ -0,0 +1,8 @@
package dto
type LocalFileData struct {
MimeType string
Base64Data string
Url string
Size int64
}

View File

@@ -86,6 +86,10 @@ type ToolCall struct {
Function FunctionCall `json:"function"`
}
func (c *ToolCall) SetIndex(i int) {
c.Index = &i
}
type FunctionCall struct {
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`

View File

@@ -33,9 +33,11 @@ var indexPage []byte
func main() {
err := godotenv.Load(".env")
if err != nil {
common.SysError("failed to load .env file: " + err.Error())
common.SysLog("Support for .env file is disabled")
}
common.LoadEnv()
common.SetupLogger()
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
@@ -80,9 +82,6 @@ func main() {
common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency))
model.InitChannelCache()
}
if common.RedisEnabled {
go model.SyncTokenCache(common.SyncFrequency)
}
if common.MemoryCacheEnabled {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)

View File

@@ -201,7 +201,7 @@ func TokenAuth() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
userEnabled, err := model.IsUserEnabled(token.UserId, false)
if err != nil {
abortWithOpenAiMessage(c, http.StatusInternalServerError, err.Error())
return

View File

@@ -10,6 +10,7 @@ import (
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"strconv"
"strings"
"time"
@@ -39,16 +40,16 @@ func Distribute() func(c *gin.Context) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error())
return
}
userGroup, _ := model.CacheGetUserGroup(userId)
userGroup, _ := model.GetUserGroup(userId, false)
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
if !setting.ContainsGroupRatio(tokenGroup) {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}

View File

@@ -3,10 +3,11 @@ package model
import (
"errors"
"fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common"
"strings"
"github.com/samber/lo"
"gorm.io/gorm"
)
type Ability struct {
@@ -22,10 +23,6 @@ type Ability struct {
func GetGroupModels(group string) []string {
var models []string
// Find distinct models
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
return models
}
@@ -44,10 +41,8 @@ func GetAllEnableAbilities() []Ability {
}
func getPriority(group string, model string, retry int) (int, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
@@ -80,10 +75,8 @@ func getPriority(group string, model string, retry int) (int, error) {
}
func getChannelQuery(group string, model string, retry int) *gorm.DB {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
@@ -173,18 +166,67 @@ func (channel *Channel) DeleteAbilities() error {
// UpdateAbilities updates abilities of this channel.
// Make sure the channel is completed before calling this function.
func (channel *Channel) UpdateAbilities() error {
// A quick and dirty way to update abilities
func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
isNewTx := false
// 如果没有传入事务,创建新的事务
if tx == nil {
tx = DB.Begin()
if tx.Error != nil {
return tx.Error
}
isNewTx = true
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
}
// First delete all abilities of this channel
err := channel.DeleteAbilities()
err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
if err != nil {
if isNewTx {
tx.Rollback()
}
return err
}
// Then add new abilities
err = channel.AddAbilities()
if err != nil {
return err
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
ability := Ability{
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
Tag: channel.Tag,
}
abilities = append(abilities, ability)
}
}
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
err = tx.Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
}
return err
}
}
}
// 如果是新创建的事务,需要提交
if isNewTx {
return tx.Commit().Error
}
return nil
}
@@ -246,7 +288,7 @@ func FixAbility() (int, error) {
return 0, err
}
for _, channel := range channels {
err := channel.UpdateAbilities()
err := channel.UpdateAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {

View File

@@ -1,209 +1,115 @@
package model
import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
"time"
)
var (
TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = common.SyncFrequency
)
// 仅用于定时同步缓存
var token2UserId = make(map[string]int)
var token2UserIdLock sync.RWMutex
func cacheSetToken(token *Token) error {
jsonBytes, err := json.Marshal(token)
if err != nil {
return err
}
err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
return err
}
token2UserIdLock.Lock()
defer token2UserIdLock.Unlock()
token2UserId[token.Key] = token.UserId
return nil
}
// CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
func CacheGetTokenByKey(key string) (*Token, error) {
if !common.RedisEnabled {
return GetTokenByKey(key)
}
var token *Token
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
// 如果缓存中不存在,则从数据库中获取
token, err = GetTokenByKey(key)
if err != nil {
return nil, err
}
err = cacheSetToken(token)
return token, nil
}
// 如果缓存中存在,则续期时间
err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
err = json.Unmarshal([]byte(tokenObjectString), &token)
return token, err
}
func SyncTokenCache(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Second)
common.SysLog("syncing tokens from database")
token2UserIdLock.Lock()
// 从token2UserId中获取所有的key
var copyToken2UserId = make(map[string]int)
for s, i := range token2UserId {
copyToken2UserId[s] = i
}
token2UserId = make(map[string]int)
token2UserIdLock.Unlock()
for key := range copyToken2UserId {
token, err := GetTokenByKey(key)
if err != nil {
// 如果数据库中不存在,则删除缓存
common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
//delete redis
err := common.RedisDel(fmt.Sprintf("token:%s", key))
if err != nil {
common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
}
} else {
// 如果数据库中存在先检查redis
_, err = common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
// 如果redis中不存在则跳过
continue
}
err = cacheSetToken(token)
if err != nil {
common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
}
}
}
}
}
func CacheGetUserGroup(id int) (group string, err error) {
if !common.RedisEnabled {
return GetUserGroup(id)
}
group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
if err != nil {
group, err = GetUserGroup(id)
if err != nil {
return "", err
}
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
}
}
return group, err
}
func CacheGetUsername(id int) (username string, err error) {
if !common.RedisEnabled {
return GetUsernameById(id)
}
username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
if err != nil {
username, err = GetUsernameById(id)
if err != nil {
return "", err
}
err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
}
}
return username, err
}
func CacheGetUserQuota(id int) (quota int, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
}
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
if err != nil {
quota, err = GetUserQuota(id)
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user quota error: " + err.Error())
}
return quota, err
}
quota, err = strconv.Atoi(quotaString)
return quota, err
}
func CacheUpdateUserQuota(id int) error {
if !common.RedisEnabled {
return nil
}
quota, err := GetUserQuota(id)
if err != nil {
return err
}
return cacheSetUserQuota(id, quota)
}
func cacheSetUserQuota(id int, quota int) error {
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
return err
}
func CacheDecreaseUserQuota(id int, quota int) error {
if !common.RedisEnabled {
return nil
}
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
return err
}
func CacheIsUserEnabled(userId int) (bool, error) {
if !common.RedisEnabled {
return IsUserEnabled(userId)
}
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
userEnabled, err := IsUserEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if userEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
}
//func CacheGetUserGroup(id int) (group string, err error) {
// if !common.RedisEnabled {
// return GetUserGroup(id)
// }
// group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
// if err != nil {
// group, err = GetUserGroup(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return group, err
//}
//
//func CacheGetUsername(id int) (username string, err error) {
// if !common.RedisEnabled {
// return GetUsernameById(id)
// }
// username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
// if err != nil {
// username, err = GetUsernameById(id)
// if err != nil {
// return "", err
// }
// err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(constant.UserId2GroupCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user group error: " + err.Error())
// }
// }
// return username, err
//}
//
//func CacheGetUserQuota(id int) (quota int, err error) {
// if !common.RedisEnabled {
// return GetUserQuota(id)
// }
// quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
// if err != nil {
// quota, err = GetUserQuota(id)
// if err != nil {
// return 0, err
// }
// return quota, nil
// }
// quota, err = strconv.Atoi(quotaString)
// return quota, nil
//}
//
//func CacheUpdateUserQuota(id int) error {
// if !common.RedisEnabled {
// return nil
// }
// quota, err := GetUserQuota(id)
// if err != nil {
// return err
// }
// return cacheSetUserQuota(id, quota)
//}
//
//func cacheSetUserQuota(id int, quota int) error {
// err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second)
// return err
//}
//
//func CacheDecreaseUserQuota(id int, quota int) error {
// if !common.RedisEnabled {
// return nil
// }
// err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
// return err
//}
//
//func CacheIsUserEnabled(userId int) (bool, error) {
// if !common.RedisEnabled {
// return IsUserEnabled(userId)
// }
// enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
// if err == nil {
// return enabled == "1", nil
// }
//
// userEnabled, err := IsUserEnabled(userId)
// if err != nil {
// return false, err
// }
// enabled = "0"
// if userEnabled {
// enabled = "1"
// }
// err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(constant.UserId2StatusCacheSeconds)*time.Second)
// if err != nil {
// common.SysError("Redis set user enabled error: " + err.Error())
// }
// return userEnabled, err
//}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
@@ -344,12 +250,12 @@ func CacheGetChannel(id int) (*Channel, error) {
}
func CacheUpdateChannelStatus(id int, status int) {
if (!common.MemoryCacheEnabled) {
return
}
channelSyncLock.Lock()
defer channelSyncLock.Unlock()
if channel, ok := channelsIDM[id]; ok {
channel.Status = status
}
if !common.MemoryCacheEnabled {
return
}
channelSyncLock.Lock()
defer channelSyncLock.Unlock()
if channel, ok := channelsIDM[id]; ok {
channel.Status = status
}
}

View File

@@ -114,14 +114,11 @@ func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) {
func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) {
var channels []*Channel
keyCol := "`key`"
groupCol := "`group`"
modelsCol := "`models`"
// 如果是 PostgreSQL使用双引号
if common.UsingPostgreSQL {
keyCol = `"key"`
groupCol = `"group"`
modelsCol = `"models"`
}
@@ -257,7 +254,7 @@ func (channel *Channel) Update() error {
return err
}
DB.Model(channel).First(channel, "id = ?", channel.Id)
err = channel.UpdateAbilities()
err = channel.UpdateAbilities(nil)
return err
}
@@ -389,7 +386,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
channels, err := GetChannelsByTag(updatedTag, false)
if err == nil {
for _, channel := range channels {
err = channel.UpdateAbilities()
err = channel.UpdateAbilities(nil)
if err != nil {
common.SysError("failed to update abilities: " + err.Error())
}
@@ -437,14 +434,10 @@ func GetPaginatedTags(offset int, limit int) ([]*string, error) {
func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) {
var tags []*string
keyCol := "`key`"
groupCol := "`group`"
modelsCol := "`models`"
// 如果是 PostgreSQL使用双引号
if common.UsingPostgreSQL {
keyCol = `"key"`
groupCol = `"group"`
modelsCol = `"models"`
}
@@ -509,3 +502,42 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
}
channel.Setting = string(settingBytes)
}
func GetChannelsByIds(ids []int) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("id in (?)", ids).Find(&channels).Error
return channels, err
}
func BatchSetChannelTag(ids []int, tag *string) error {
// 开启事务
tx := DB.Begin()
if tx.Error != nil {
return tx.Error
}
// 更新标签
err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
if err != nil {
tx.Rollback()
return err
}
// update ability status
channels, err := GetChannelsByIds(ids)
if err != nil {
tx.Rollback()
return err
}
for _, channel := range channels {
err = channel.UpdateAbilities(tx)
if err != nil {
tx.Rollback()
return err
}
}
// 提交事务
return tx.Commit().Error
}

View File

@@ -12,16 +12,6 @@ import (
"gorm.io/gorm"
)
var groupCol string
func init() {
if common.UsingPostgreSQL {
groupCol = `"group"`
} else {
groupCol = "`group`"
}
}
type Log struct {
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`
@@ -50,16 +40,30 @@ const (
LogTypeSystem
)
func formatUserLogs(logs []*Log) {
for i := range logs {
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")
}
logs[i].Other = common.MapToJsonStr(otherMap)
logs[i].Id = logs[i].Id % 1024
}
}
func GetLogByKey(key string) (logs []*Log, err error) {
if os.Getenv("LOG_SQL_DSN") != "" {
var tk Token
if err = DB.Model(&Token{}).Where("`key`=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
return nil, err
}
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
} else {
err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
}
formatUserLogs(logs)
return logs, err
}
@@ -67,7 +71,7 @@ func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)
username, _ := GetUsernameById(userId, false)
log := &Log{
UserId: userId,
Username: username,
@@ -88,7 +92,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
if !common.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)
username, _ := GetUsernameById(userId, false)
otherStr := common.MapToJsonStr(other)
log := &Log{
UserId: userId,
@@ -184,16 +188,8 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if err != nil {
return nil, 0, err
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error
for i := range logs {
var otherMap map[string]interface{}
otherMap = common.StrToMap(logs[i].Other)
if otherMap != nil {
// delete admin
delete(otherMap, "admin_info")
}
logs[i].Other = common.MapToJsonStr(otherMap)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
formatUserLogs(logs)
return logs, total, err
}
@@ -203,7 +199,8 @@ func SearchAllLogs(keyword string) (logs []*Log, err error) {
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
formatUserLogs(logs)
return logs, err
}

View File

@@ -13,6 +13,20 @@ import (
"time"
)
var groupCol string
var keyCol string
func initCol() {
if common.UsingPostgreSQL {
groupCol = `"group"`
keyCol = `"key"`
} else {
groupCol = "`group`"
keyCol = "`key`"
}
}
var DB *gorm.DB
var LOG_DB *gorm.DB
@@ -41,6 +55,9 @@ func createRootAccountIfNeed() error {
}
func chooseDB(envName string) (*gorm.DB, error) {
defer func() {
initCol()
}()
dsn := os.Getenv(envName)
if dsn != "" {
if strings.HasPrefix(dsn, "postgres://") {

View File

@@ -87,8 +87,8 @@ func InitOptionMap() {
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":

View File

@@ -3,8 +3,10 @@ package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strconv"
"gorm.io/gorm"
)
type Redemption struct {
@@ -21,16 +23,80 @@ type Redemption struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
var redemptions []*Redemption
var err error
err = DB.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
return redemptions, err
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
// 开始事务
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 获取总数
err = tx.Model(&Redemption{}).Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 获取分页数据
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 提交事务
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return redemptions, total, nil
}
func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) {
err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error
return redemptions, err
func SearchRedemptions(keyword string, startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// Build query based on keyword type
query := tx.Model(&Redemption{})
// Only try to convert to ID if the string represents a valid integer
if id, err := strconv.Atoi(keyword); err == nil {
query = query.Where("id = ? OR name LIKE ?", id, keyword+"%")
} else {
query = query.Where("name LIKE ?", keyword+"%")
}
// Get total count
err = query.Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// Get paginated data
err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&redemptions).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return redemptions, total, nil
}
func GetRedemptionById(id int) (*Redemption, error) {

View File

@@ -3,6 +3,7 @@ package model
import (
"errors"
"fmt"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
"one-api/common"
relaycommon "one-api/relay/common"
@@ -30,6 +31,10 @@ type Token struct {
DeletedAt gorm.DeletedAt `gorm:"index"`
}
func (token *Token) Clean() {
token.Key = ""
}
func (token *Token) GetIpLimitsMap() map[string]any {
// delete empty spaces
//split with \n
@@ -63,7 +68,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
if token != "" {
token = strings.Trim(token, "sk-")
}
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where("`key` LIKE ?", "%"+token+"%").Find(&tokens).Error
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
return tokens, err
}
@@ -71,7 +76,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if key == "" {
return nil, errors.New("未提供令牌")
}
token, err = CacheGetTokenByKey(key)
token, err = GetTokenByKey(key, false)
if err == nil {
if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3]
@@ -128,22 +133,38 @@ func GetTokenById(id int) (*Token, error) {
token := Token{Id: id}
var err error = nil
err = DB.First(&token, "id = ?", id).Error
if err != nil {
if common.RedisEnabled {
go cacheSetToken(&token)
}
if shouldUpdateRedis(true, err) {
gopool.Go(func() {
if err := cacheSetToken(token); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
return &token, err
}
func GetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) && token != nil {
gopool.Go(func() {
if err := cacheSetToken(*token); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
token, err := cacheGetTokenByKey(key)
if err == nil {
return token, nil
}
// Don't return error - fall through to DB
}
var token Token
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
fromDB = true
err = DB.Where(keyCol+" = ?", key).First(&token).Error
return token, err
}
func (token *Token) Insert() error {
@@ -153,20 +174,48 @@ func (token *Token) Insert() error {
}
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
func (token *Token) Update() (err error) {
defer func() {
if shouldUpdateRedis(true, err) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
common.SysError("failed to update token cache: " + err.Error())
}
})
}
}()
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota",
"model_limits_enabled", "model_limits", "allow_ips", "group").Updates(token).Error
return err
}
func (token *Token) SelectUpdate() error {
func (token *Token) SelectUpdate() (err error) {
defer func() {
if shouldUpdateRedis(true, err) {
gopool.Go(func() {
err := cacheSetToken(*token)
if err != nil {
common.SysError("failed to update token cache: " + err.Error())
}
})
}
}()
// This can update zero values
return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
}
func (token *Token) Delete() error {
var err error
func (token *Token) Delete() (err error) {
defer func() {
if shouldUpdateRedis(true, err) {
gopool.Go(func() {
err := cacheDeleteToken(token.Key)
if err != nil {
common.SysError("failed to delete token cache: " + err.Error())
}
})
}
}()
err = DB.Delete(token).Error
return err
}
@@ -214,10 +263,18 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
func IncreaseTokenQuota(id int, quota int) (err error) {
func IncreaseTokenQuota(id int, key string, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.RedisEnabled {
gopool.Go(func() {
err := cacheIncrTokenQuota(key, int64(quota))
if err != nil {
common.SysError("failed to increase token quota: " + err.Error())
}
})
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
@@ -236,10 +293,18 @@ func increaseTokenQuota(id int, quota int) (err error) {
return err
}
func DecreaseTokenQuota(id int, quota int) (err error) {
func DecreaseTokenQuota(id int, key string, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
if common.RedisEnabled {
gopool.Go(func() {
err := cacheDecrTokenQuota(key, int64(quota))
if err != nil {
common.SysError("failed to decrease token quota: " + err.Error())
}
})
}
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
@@ -258,37 +323,31 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err
}
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) (userQuota int, err error) {
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
if quota < 0 {
return 0, errors.New("quota 不能为负数!")
return errors.New("quota 不能为负数!")
}
if !relayInfo.IsPlayground {
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return 0, err
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return 0, errors.New("令牌额度不足")
}
if relayInfo.IsPlayground {
return nil
}
userQuota, err = GetUserQuota(relayInfo.UserId)
//if relayInfo.TokenUnlimited {
// return nil
//}
token, err := GetTokenById(relayInfo.TokenId)
if err != nil {
return 0, err
return err
}
if userQuota < quota {
return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return errors.New("令牌额度不足")
}
if !relayInfo.IsPlayground {
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
if err != nil {
return 0, err
}
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
return err
}
err = DecreaseUserQuota(relayInfo.UserId, quota)
return userQuota - quota, err
return nil
}
func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
func PostConsumeQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
if quota > 0 {
err = DecreaseUserQuota(relayInfo.UserId, quota)
@@ -301,9 +360,9 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
if !relayInfo.IsPlayground {
if quota > 0 {
err = DecreaseTokenQuota(relayInfo.TokenId, quota)
err = DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
} else {
err = IncreaseTokenQuota(relayInfo.TokenId, -quota)
err = IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, -quota)
}
if err != nil {
return err

64
model/token_cache.go Normal file
View File

@@ -0,0 +1,64 @@
package model
import (
"fmt"
"one-api/common"
"one-api/constant"
"time"
)
func cacheSetToken(token Token) error {
key := common.GenerateHMAC(token.Key)
token.Clean()
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
if err != nil {
return err
}
return nil
}
func cacheDeleteToken(key string) error {
key = common.GenerateHMAC(key)
err := common.RedisHDelObj(fmt.Sprintf("token:%s", key))
if err != nil {
return err
}
return nil
}
func cacheIncrTokenQuota(key string, increment int64) error {
key = common.GenerateHMAC(key)
err := common.RedisHIncrBy(fmt.Sprintf("token:%s", key), constant.TokenFiledRemainQuota, increment)
if err != nil {
return err
}
return nil
}
func cacheDecrTokenQuota(key string, decrement int64) error {
return cacheIncrTokenQuota(key, -decrement)
}
func cacheSetTokenField(key string, field string, value string) error {
key = common.GenerateHMAC(key)
err := common.RedisHSetField(fmt.Sprintf("token:%s", key), field, value)
if err != nil {
return err
}
return nil
}
// CacheGetTokenByKey 从缓存中获取 token如果缓存中不存在则从数据库中获取
func cacheGetTokenByKey(key string) (*Token, error) {
hmacKey := common.GenerateHMAC(key)
if !common.RedisEnabled {
return nil, nil
}
var token Token
err := common.RedisHGetObj(fmt.Sprintf("token:%s", hmacKey), &token)
if err != nil {
return nil, err
}
token.Key = key
return &token, nil
}

View File

@@ -6,7 +6,8 @@ import (
"one-api/common"
"strconv"
"strings"
"time"
"github.com/bytedance/gopkg/util/gopool"
"gorm.io/gorm"
)
@@ -80,46 +81,100 @@ func GetMaxUserId() int {
return user.Id
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
return users, err
func GetAllUsers(startIdx int, num int) (users []*User, total int64, err error) {
// Start transaction
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// Get total count within transaction
err = tx.Unscoped().Model(&User{}).Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// Get paginated users within same transaction
err = tx.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// Commit transaction
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
func SearchUsers(keyword string, group string) ([]*User, error) {
func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User, int64, error) {
var users []*User
var total int64
var err error
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
// 开始事务
tx := DB.Begin()
if tx.Error != nil {
return nil, 0, tx.Error
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 构建基础查询
query := tx.Unscoped().Model(&User{})
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果转换成功按照ID和可选的组别搜索用户
query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
if group != "" {
query = query.Where(groupCol+" = ?", group) // 使用反引号包围group
query = query.Where("id = ? AND "+groupCol+" = ?", keywordInt, group)
} else {
query = query.Where("id = ?", keywordInt)
}
err = query.Find(&users).Error
if err != nil || len(users) > 0 {
return users, err
}
}
err = nil
query := DB.Unscoped().Omit("password")
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
if group != "" {
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
// 如果不是ID搜索则使用模糊匹配
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
if group != "" {
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition,
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}
}
err = query.Find(&users).Error
return users, err
// 获取总数
err = query.Count(&total).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 获取分页数据
err = query.Omit("password").Order("id desc").Limit(num).Offset(startIdx).Find(&users).Error
if err != nil {
tx.Rollback()
return nil, 0, err
}
// 提交事务
if err = tx.Commit().Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
func GetUserById(id int, selectAll bool) (*User, error) {
@@ -251,14 +306,12 @@ func (user *User) Update(updatePassword bool) error {
}
newUser := *user
DB.First(&user, user.Id)
err = DB.Model(user).Updates(newUser).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
if err = DB.Model(user).Updates(newUser).Error; err != nil {
return err
}
return err
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
}
func (user *User) Edit(updatePassword bool) error {
@@ -269,6 +322,7 @@ func (user *User) Edit(updatePassword bool) error {
return err
}
}
newUser := *user
updates := map[string]interface{}{
"username": newUser.Username,
@@ -279,23 +333,26 @@ func (user *User) Edit(updatePassword bool) error {
if updatePassword {
updates["password"] = newUser.Password
}
DB.First(&user, user.Id)
err = DB.Model(user).Updates(updates).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
}
if err = DB.Model(user).Updates(updates).Error; err != nil {
return err
}
return err
// 更新缓存
return updateUserCache(user.Id, user.Username, user.Group, user.Quota, user.Status)
}
func (user *User) Delete() error {
if user.Id == 0 {
return errors.New("id 为空!")
}
err := DB.Delete(user).Error
return err
if err := DB.Delete(user).Error; err != nil {
return err
}
// 清除缓存
return invalidateUserCache(user.Id)
}
func (user *User) HardDelete() error {
@@ -409,15 +466,33 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser
}
func IsUserEnabled(userId int) (bool, error) {
if userId == 0 {
return false, errors.New("user id is empty")
// IsUserEnabled checks user status from Redis first, falls back to DB if needed
func IsUserEnabled(id int, fromDB bool) (status bool, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserStatusCache(id, status); err != nil {
common.SysError("failed to update user status cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
// Try Redis first
status, err := getUserStatusCache(id)
if err == nil {
return status == common.UserStatusEnabled, nil
}
// Don't return error - fall through to DB
}
fromDB = true
var user User
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
err = DB.Where("id = ?", id).Select("status").Find(&user).Error
if err != nil {
return false, err
}
return user.Status == common.UserStatusEnabled, nil
}
@@ -433,14 +508,33 @@ func ValidateAccessToken(token string) (user *User) {
return nil
}
func GetUserQuota(id int) (quota int, err error) {
// GetUserQuota gets quota from Redis first, falls back to DB if needed
func GetUserQuota(id int, fromDB bool) (quota int, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserQuotaCache(id, quota); err != nil {
common.SysError("failed to update user quota cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
quota, err := getUserQuotaCache(id)
if err == nil {
return quota, nil
}
// Don't return error - fall through to DB
//common.SysError("failed to get user quota from cache: " + err.Error())
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
if err != nil {
if common.RedisEnabled {
go cacheSetUserQuota(id, quota)
}
return 0, err
}
return quota, err
return quota, nil
}
func GetUserUsedQuota(id int) (quota int, err error) {
@@ -453,20 +547,44 @@ func GetUserEmail(id int) (email string, err error) {
return email, err
}
func GetUserGroup(id int) (group string, err error) {
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
// GetUserGroup gets group from Redis first, falls back to DB if needed
func GetUserGroup(id int, fromDB bool) (group string, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserGroupCache(id, group); err != nil {
common.SysError("failed to update user group cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
group, err := getUserGroupCache(id)
if err == nil {
return group, nil
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
if err != nil {
return "", err
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err
return group, nil
}
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
gopool.Go(func() {
err := cacheIncrUserQuota(id, int64(quota))
if err != nil {
common.SysError("failed to increase user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
@@ -476,6 +594,9 @@ func IncreaseUserQuota(id int, quota int) (err error) {
func increaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
if err != nil {
return err
}
return err
}
@@ -483,6 +604,12 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
gopool.Go(func() {
err := cacheDecrUserQuota(id, int64(quota))
if err != nil {
common.SysError("failed to decrease user quota: " + err.Error())
}
})
if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
@@ -492,9 +619,23 @@ func DecreaseUserQuota(id int, quota int) (err error) {
func decreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
if err != nil {
return err
}
return err
}
func DeltaUpdateUserQuota(id int, delta int) (err error) {
if delta == 0 {
return nil
}
if delta > 0 {
return IncreaseUserQuota(id, delta)
} else {
return DecreaseUserQuota(id, -delta)
}
}
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
@@ -518,7 +659,13 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
).Error
if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error())
return
}
//// 更新缓存
//if err := invalidateUserCache(id); err != nil {
// common.SysError("failed to invalidate user cache: " + err.Error())
//}
}
func updateUserUsedQuota(id int, quota int) {
@@ -539,9 +686,32 @@ func updateUserRequestCount(id int, count int) {
}
}
func GetUsernameById(id int) (username string, err error) {
// GetUsernameById gets username from Redis first, falls back to DB if needed
func GetUsernameById(id int, fromDB bool) (username string, err error) {
defer func() {
// Update Redis cache asynchronously on successful DB read
if shouldUpdateRedis(fromDB, err) {
gopool.Go(func() {
if err := updateUserNameCache(id, username); err != nil {
common.SysError("failed to update user name cache: " + err.Error())
}
})
}
}()
if !fromDB && common.RedisEnabled {
username, err := getUserNameCache(id)
if err == nil {
return username, nil
}
// Don't return error - fall through to DB
}
fromDB = true
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
return username, err
if err != nil {
return "", err
}
return username, nil
}
func IsLinuxDOIdAlreadyTaken(linuxDOId string) bool {

206
model/user_cache.go Normal file
View File

@@ -0,0 +1,206 @@
package model
import (
"fmt"
"one-api/common"
"one-api/constant"
"strconv"
"time"
)
// Change UserCache struct to userCache
type userCache struct {
Id int `json:"id"`
Group string `json:"group"`
Quota int `json:"quota"`
Status int `json:"status"`
Role int `json:"role"`
Username string `json:"username"`
}
// Rename all exported functions to private ones
// invalidateUserCache clears all user related cache
func invalidateUserCache(userId int) error {
if !common.RedisEnabled {
return nil
}
keys := []string{
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
}
for _, key := range keys {
if err := common.RedisDel(key); err != nil {
return fmt.Errorf("failed to delete cache key %s: %w", key, err)
}
}
return nil
}
// updateUserGroupCache updates user group cache
func updateUserGroupCache(userId int, group string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserGroupKeyFmt, userId),
group,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserQuotaCache updates user quota cache
func updateUserQuotaCache(userId int, quota int) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserQuotaKeyFmt, userId),
fmt.Sprintf("%d", quota),
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserStatusCache updates user status cache
func updateUserStatusCache(userId int, userEnabled bool) error {
if !common.RedisEnabled {
return nil
}
enabled := "0"
if userEnabled {
enabled = "1"
}
return common.RedisSet(
fmt.Sprintf(constant.UserEnabledKeyFmt, userId),
enabled,
time.Duration(constant.UserId2StatusCacheSeconds)*time.Second,
)
}
// updateUserNameCache updates username cache
func updateUserNameCache(userId int, username string) error {
if !common.RedisEnabled {
return nil
}
return common.RedisSet(
fmt.Sprintf(constant.UserUsernameKeyFmt, userId),
username,
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
)
}
// updateUserCache updates all user cache fields
func updateUserCache(userId int, username string, userGroup string, quota int, status int) error {
if !common.RedisEnabled {
return nil
}
if err := updateUserGroupCache(userId, userGroup); err != nil {
return fmt.Errorf("update group cache: %w", err)
}
if err := updateUserQuotaCache(userId, quota); err != nil {
return fmt.Errorf("update quota cache: %w", err)
}
if err := updateUserStatusCache(userId, status == common.UserStatusEnabled); err != nil {
return fmt.Errorf("update status cache: %w", err)
}
if err := updateUserNameCache(userId, username); err != nil {
return fmt.Errorf("update username cache: %w", err)
}
return nil
}
// getUserGroupCache gets user group from cache
func getUserGroupCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserGroupKeyFmt, userId))
}
// getUserQuotaCache gets user quota from cache
func getUserQuotaCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
quotaStr, err := common.RedisGet(fmt.Sprintf(constant.UserQuotaKeyFmt, userId))
if err != nil {
return 0, err
}
return strconv.Atoi(quotaStr)
}
// getUserStatusCache gets user status from cache
func getUserStatusCache(userId int) (int, error) {
if !common.RedisEnabled {
return 0, nil
}
statusStr, err := common.RedisGet(fmt.Sprintf(constant.UserEnabledKeyFmt, userId))
if err != nil {
return 0, err
}
return strconv.Atoi(statusStr)
}
// getUserNameCache gets username from cache
func getUserNameCache(userId int) (string, error) {
if !common.RedisEnabled {
return "", nil
}
return common.RedisGet(fmt.Sprintf(constant.UserUsernameKeyFmt, userId))
}
// getUserCache gets complete user cache
func getUserCache(userId int) (*userCache, error) {
if !common.RedisEnabled {
return nil, nil
}
group, err := getUserGroupCache(userId)
if err != nil {
return nil, fmt.Errorf("get group cache: %w", err)
}
quota, err := getUserQuotaCache(userId)
if err != nil {
return nil, fmt.Errorf("get quota cache: %w", err)
}
status, err := getUserStatusCache(userId)
if err != nil {
return nil, fmt.Errorf("get status cache: %w", err)
}
username, err := getUserNameCache(userId)
if err != nil {
return nil, fmt.Errorf("get username cache: %w", err)
}
return &userCache{
Id: userId,
Group: group,
Quota: quota,
Status: status,
Username: username,
}, nil
}
// Add atomic quota operations
func cacheIncrUserQuota(userId int, delta int64) error {
if !common.RedisEnabled {
return nil
}
key := fmt.Sprintf(constant.UserQuotaKeyFmt, userId)
return common.RedisIncr(key, delta)
}
func cacheDecrUserQuota(userId int, delta int64) error {
return cacheIncrUserQuota(userId, -delta)
}

View File

@@ -88,3 +88,7 @@ func RecordExist(err error) (bool, error) {
}
return false, err
}
func shouldUpdateRedis(fromDB bool, err error) bool {
return common.RedisEnabled && fromDB && err == nil
}

View File

@@ -225,9 +225,12 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
claudeMediaMessage.Source.MediaType = fileData.MimeType
claudeMediaMessage.Source.Data = fileData.Base64Data
} else {
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
if err != nil {

View File

@@ -4,7 +4,7 @@ type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
}
@@ -18,16 +18,39 @@ type FunctionCall struct {
Arguments any `json:"args"`
}
type GeminiFunctionResponseContent struct {
Name string `json:"name"`
Content any `json:"content"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response any `json:"response"`
Name string `json:"name"`
Response GeminiFunctionResponseContent `json:"response"`
}
type GeminiPartExecutableCode struct {
Language string `json:"language,omitempty"`
Code string `json:"code,omitempty"`
}
type GeminiPartCodeExecutionResult struct {
Outcome string `json:"outcome,omitempty"`
Output string `json:"output,omitempty"`
}
type GeminiFileData struct {
MimeType string `json:"mimeType,omitempty"`
FileUri string `json:"fileUri,omitempty"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
FileData *GeminiFileData `json:"fileData,omitempty"`
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
}
type GeminiChatContent struct {
@@ -40,9 +63,11 @@ type GeminiChatSafetySettings struct {
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
GoogleSearch any `json:"googleSearch,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
type GeminiChatTool struct {
GoogleSearch any `json:"googleSearch,omitempty"`
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
CodeExecution any `json:"codeExecution,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
@@ -54,11 +79,12 @@ type GeminiChatGenerationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
FinishReason *string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

View File

@@ -12,12 +12,14 @@ import (
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"unicode/utf8"
"github.com/gin-gonic/gin"
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
@@ -46,16 +48,24 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
Seed: int64(textRequest.Seed),
},
}
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
if tool.Function.Name == "googleSearch" {
googleSearch = true
continue
}
if tool.Function.Name == "codeExecution" {
codeExecution = true
continue
}
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
if ok {
@@ -68,25 +78,32 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
functions = append(functions, tool.Function)
}
if len(functions) > 0 {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: functions,
},
}
if codeExecution {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
CodeExecution: make(map[string]string),
})
}
if googleSearch {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTools{
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
GoogleSearch: make(map[string]string),
})
}
if len(functions) > 0 {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
FunctionDeclarations: functions,
})
}
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
} else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
geminiRequest.Tools = []GeminiChatTool{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
@@ -96,20 +113,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
}
tool_call_ids := make(map[string]string)
var system_content []string
//shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
if message.Role == "system" {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
{
Text: message.StringContent(),
},
},
}
system_content = append(system_content, message.StringContent())
continue
} else if message.Role == "tool" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
} else if message.Role == "tool" || message.Role == "function" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "user",
})
@@ -121,9 +132,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
content := common.StrToMap(message.StringContent())
functionResp := &FunctionResponse{
Name: name,
Response: common.StrToMap(message.StringContent()),
Name: name,
Response: GeminiFunctionResponseContent{
Name: name,
Content: content,
},
}
if content == nil {
functionResp.Response.Content = message.StringContent()
}
*parts = append(*parts, GeminiPart{
FunctionResponse: functionResp,
@@ -134,57 +152,68 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
content := GeminiChatContent{
Role: message.Role,
}
isToolCall := false
// isToolCall := false
if message.ToolCalls != nil {
message.Role = "model"
isToolCall = true
// message.Role = "model"
// isToolCall = true
for _, call := range message.ParseToolCalls() {
args := map[string]interface{}{}
if call.Function.Arguments != "" {
if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
}
}
toolCall := GeminiPart{
FunctionCall: &FunctionCall{
FunctionName: call.Function.Name,
Arguments: call.Function.Parameters,
Arguments: args,
},
}
parts = append(parts, toolCall)
tool_call_ids[call.ID] = call.Function.Name
}
}
if !isToolCall {
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
if part.Text == "" {
continue
}
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: fileData.MimeType,
Data: fileData.Base64Data,
},
})
} else {
format, base64String, err := service.DecodeBase64FileData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
}
}
}
@@ -197,6 +226,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
}
if len(system_content) > 0 {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
{
Text: strings.Join(system_content, "\n"),
},
},
}
}
return &geminiRequest, nil
}
@@ -209,12 +249,12 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
if !ok || len(v) == 0 {
return schema
}
// 删除所有的title字段
delete(v, "title")
// 如果type不为object和array则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema
}
delete(v, "title")
switch v["type"] {
case "object":
delete(v, "additionalProperties")
@@ -240,20 +280,85 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
return v
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
func unescapeString(s string) (string, error) {
var result []rune
escaped := false
i := 0
for i < len(s) {
r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
if r == utf8.RuneError {
return "", fmt.Errorf("invalid UTF-8 encoding")
}
if escaped {
// 如果是转义符后的字符,检查其类型
switch r {
case '"':
result = append(result, '"')
case '\\':
result = append(result, '\\')
case '/':
result = append(result, '/')
case 'b':
result = append(result, '\b')
case 'f':
result = append(result, '\f')
case 'n':
result = append(result, '\n')
case 'r':
result = append(result, '\r')
case 't':
result = append(result, '\t')
case '\'':
result = append(result, '\'')
default:
// 如果遇到一个非法的转义字符,直接按原样输出
result = append(result, '\\', r)
}
escaped = false
} else {
if r == '\\' {
escaped = true // 记录反斜杠作为转义符
} else {
result = append(result, r)
}
}
i += size // 移动到下一个字符
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
return string(result), nil
}
func unescapeMapOrSlice(data interface{}) interface{} {
switch v := data.(type) {
case map[string]interface{}:
for k, val := range v {
v[k] = unescapeMapOrSlice(val)
}
case []interface{}:
for i, val := range v {
v[i] = unescapeMapOrSlice(val)
}
case string:
if unescaped, err := unescapeString(v); err != nil {
return v
} else {
return unescaped
}
}
return ""
return data
}
func getToolCall(item *GeminiPart) *dto.ToolCall {
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
} else {
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
}
if err != nil {
//common.SysError("getToolCall failed: " + err.Error())
return nil
}
return &dto.ToolCall{
@@ -266,30 +371,6 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
}
}
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
// var toolCalls []dto.ToolCall
// item := candidate.Content.Parts[index]
// if item.FunctionCall == nil {
// return toolCalls
// }
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
// if err != nil {
// //common.SysError("getToolCalls failed: " + err.Error())
// return toolCalls
// }
// toolCall := dto.ToolCall{
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
// Type: "function",
// Function: dto.FunctionCall{
// Arguments: string(argsBytes),
// Name: item.FunctionCall.FunctionName,
// },
// }
// toolCalls = append(toolCalls, toolCall)
// return toolCalls
// }
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -298,11 +379,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
// jsonData, _ := json.MarshalIndent(candidate, "", " ")
// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
is_tool_call := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: i,
Index: int(candidate.Index),
Message: dto.Message{
Role: "assistant",
Content: content,
@@ -319,48 +399,107 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
tool_calls = append(tool_calls, *call)
}
} else {
texts = append(texts, part.Text)
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
} else {
// 过滤掉空行
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
}
if len(tool_calls) > 0 {
choice.Message.SetToolCalls(tool_calls)
is_tool_call = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
choice.Message.SetToolCalls(tool_calls)
}
if candidate.FinishReason != nil {
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = constant.FinishReasonLength
default:
choice.FinishReason = constant.FinishReasonContentFilter
}
}
if is_tool_call {
choice.FinishReason = constant.FinishReasonToolCalls
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
is_stop := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
is_stop = true
candidate.FinishReason = nil
}
choice := dto.ChatCompletionsStreamResponseChoice{
Index: int(candidate.Index),
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
},
}
var texts []string
var tool_calls []dto.ToolCall
for _, part := range geminiResponse.Candidates[0].Content.Parts {
if part.FunctionCall != nil {
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
}
} else {
texts = append(texts, part.Text)
isTools := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = &constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = &constant.FinishReasonLength
default:
choice.FinishReason = &constant.FinishReasonContentFilter
}
}
if len(texts) > 0 {
choice.Delta.SetContentString(strings.Join(texts, "\n"))
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
isTools = true
if call := getToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
} else {
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
}
if len(tool_calls) > 0 {
choice.Delta.ToolCalls = tool_calls
choice.Delta.SetContentString(strings.Join(texts, "\n"))
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
}
choices = append(choices, choice)
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
response.Choices = choices
return &response, is_stop
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
// responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
@@ -384,14 +523,11 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
continue
}
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
continue
}
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
responseText += response.Choices[0].Delta.GetContentString()
// responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -400,12 +536,17 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
if err != nil {
common.LogError(c, err.Error())
}
if is_stop {
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
}
}
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
var response *dto.ChatCompletionsStreamResponse
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
if info.ShouldIncludeUsage {
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)

View File

@@ -106,7 +106,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI {
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
request.StreamOptions = nil
}
if strings.HasPrefix(request.Model, "o1") {

View File

@@ -109,7 +109,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare {
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
info.SupportStreamOptions = true
}
return info

View File

@@ -74,30 +74,16 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota-preConsumedQuota < 0 {
return service.OpenAIErrorWrapperLocal(errors.New(fmt.Sprintf("audio pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota)), "insufficient_user_quota", http.StatusBadRequest)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
preConsumedQuota, userQuota, openaiErr = preConsumeQuota(c, preConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {

View File

@@ -99,8 +99,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
modelPrice = 0.0025 * modelRatio
}
groupRatio := common.GetGroupRatio(relayInfo.Group)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
sizeRatio := 1.0
// Size

View File

@@ -168,9 +168,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
@@ -194,11 +194,11 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
//err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
@@ -474,9 +474,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
userQuota, err := model.GetUserQuota(userId, false)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
@@ -500,14 +500,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quota, 0, true)
err := model.PostConsumeQuota(relayInfo, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)

View File

@@ -94,7 +94,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
@@ -108,10 +108,18 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
// 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens = value.(int)
relayInfo.PromptTokens = promptTokens
} else {
promptTokens, err = getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
c.Set("prompt_tokens", promptTokens)
}
if !getModelPriceSuccess {
@@ -223,7 +231,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
var err error
switch info.RelayMode {
case relayconstant.RelayModeChatCompletions:
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
case relayconstant.RelayModeCompletions:
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
case relayconstant.RelayModeModerations:
@@ -255,7 +263,7 @@ func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycom
// 预扣费并返回用户剩余配额
func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, int, *dto.OpenAIErrorWithStatusCode) {
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
}
@@ -265,10 +273,6 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if userQuota-preConsumedQuota < 0 {
return 0, 0, service.OpenAIErrorWrapperLocal(fmt.Errorf("chat pre-consumed quota failed, user quota: %d, need quota: %d", userQuota, preConsumedQuota), "insufficient_user_quota", http.StatusBadRequest)
}
err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
@@ -286,11 +290,16 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
}
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
err = model.PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, 0, service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
}
return preConsumedQuota, userQuota, nil
}
@@ -300,7 +309,7 @@ func returnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, us
go func() {
relayInfoCopy := *relayInfo
err := model.PostConsumeTokenQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
err := model.PostConsumeQuota(&relayInfoCopy, userQuota, -preConsumedQuota, 0, false)
if err != nil {
common.SysError("error return pre-consumed quota: " + err.Error())
}
@@ -358,15 +367,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
//}
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
err := model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64

View File

@@ -16,6 +16,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
)
/*
@@ -48,9 +49,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
return
@@ -112,14 +113,10 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeTokenQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
err := model.PostConsumeQuota(relayInfo.ToRelayInfo(), userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
)
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64

View File

@@ -28,10 +28,10 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
userRoute := apiRouter.Group("/user")
{
@@ -99,7 +99,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())

View File

@@ -9,9 +9,12 @@ import (
"strings"
)
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
func DoDownloadRequest(originUrl string) (resp *http.Response, err error) {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
common.SysLog(fmt.Sprintf("downloading file from worker: %s", originUrl))
if !strings.HasPrefix(originUrl, "https") {
return nil, fmt.Errorf("only support https url")
}
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
@@ -20,7 +23,7 @@ func DoImageRequest(originUrl string) (resp *http.Response, err error) {
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
common.SysLog(fmt.Sprintf("downloading from origin: %s", originUrl))
return http.Get(originUrl)
}
}

39
service/file_decoder.go Normal file
View File

@@ -0,0 +1,39 @@
package service
import (
"encoding/base64"
"fmt"
"io"
"one-api/constant"
"one-api/dto"
)
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
resp, err := DoDownloadRequest(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// Always use LimitReader to prevent oversized downloads
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
if err != nil {
return nil, err
}
// Check actual size after reading
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
}
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
return &dto.LocalFileData{
Base64Data: base64Data,
MimeType: resp.Header.Get("Content-Type"),
Size: int64(len(fileBytes)),
}, nil
}

View File

@@ -5,11 +5,12 @@ import (
"encoding/base64"
"errors"
"fmt"
"golang.org/x/image/webp"
"image"
"io"
"one-api/common"
"strings"
"golang.org/x/image/webp"
)
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
@@ -31,14 +32,39 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e
return config, format, base64String, err
}
func DecodeBase64FileData(base64String string) (string, string, error) {
var mimeType string
var idx int
idx = strings.Index(base64String, ",")
if idx == -1 {
_, file_type, base64, err := DecodeBase64ImageData(base64String)
return "image/" + file_type, base64, err
}
mimeType = base64String[:idx]
base64String = base64String[idx+1:]
idx = strings.Index(mimeType, ";")
if idx == -1 {
_, file_type, base64, err := DecodeBase64ImageData(base64String)
return "image/" + file_type, base64, err
}
mimeType = mimeType[:idx]
idx = strings.Index(mimeType, ":")
if idx == -1 {
_, file_type, base64, err := DecodeBase64ImageData(base64String)
return "image/" + file_type, base64, err
}
mimeType = mimeType[idx+1:]
return mimeType, base64String, nil
}
// GetImageFromUrl 获取图片的类型和base64编码的数据
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
resp, err := DoImageRequest(url)
resp, err := DoDownloadRequest(url)
if err != nil {
return
return "", "", err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return
return "", "", fmt.Errorf("invalid content type: %s, required image/*", resp.Header.Get("Content-Type"))
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
@@ -52,7 +78,7 @@ func GetImageFromUrl(url string) (mimeType string, data string, err error) {
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := DoImageRequest(imageUrl)
response, err := DoDownloadRequest(imageUrl)
if err != nil {
common.SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, "", err
@@ -64,6 +90,12 @@ func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
return image.Config{}, "", err
}
mimeType := response.Header.Get("Content-Type")
if !strings.HasPrefix(mimeType, "image/") {
return image.Config{}, "", fmt.Errorf("invalid content type: %s, required image/*", mimeType)
}
var readData []byte
for _, limit := range []int64{1024 * 8, 1024 * 24, 1024 * 64} {
common.SysLog(fmt.Sprintf("try to decode image config with limit: %d", limit))

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/setting"
"strings"
"time"
)
@@ -17,12 +18,12 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
if relayInfo.UsePrice {
return nil
}
userQuota, err := model.GetUserQuota(relayInfo.UserId)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return err
}
token, err := model.CacheGetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"))
token, err := model.GetTokenByKey(strings.TrimLeft(relayInfo.TokenKey, "sk-"), false)
if err != nil {
return err
}
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName)
ratio := groupRatio * modelRatio
@@ -57,15 +58,11 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
return errors.New(fmt.Sprintf("令牌额度不足,剩余额度为 %d", token.RemainQuota))
}
err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false)
err = model.PostConsumeQuota(relayInfo, 0, quota, 0, false)
if err != nil {
return err
}
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
return err
}
return nil
}
@@ -119,7 +116,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
//}
//quotaDelta := quota - preConsumedQuota
//if quotaDelta != 0 {
// err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
// err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
// if err != nil {
// common.LogError(ctx, "error consuming token remain quota: "+err.Error())
// }
@@ -189,15 +186,11 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
} else {
quotaDelta := quota - preConsumedQuota
if quotaDelta != 0 {
err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
err := model.PostConsumeQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
err := model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}

View File

@@ -19,42 +19,40 @@ import (
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
var cl200kTokenEncoder *tiktoken.Tiktoken
var o200kTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
defaultTokenEncoder = cl100TokenEncoder
o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
}
for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
tokenEncoderMap[model] = cl100TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = cl200kTokenEncoder
tokenEncoderMap[model] = o200kTokenEncoder
} else {
tokenEncoderMap[model] = gpt4TokenEncoder
tokenEncoderMap[model] = defaultTokenEncoder
}
} else if strings.HasPrefix(model, "o1") {
tokenEncoderMap[model] = o200kTokenEncoder
} else {
tokenEncoderMap[model] = nil
tokenEncoderMap[model] = defaultTokenEncoder
}
}
common.SysLog("token encoders initialized")
}
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
return cl200kTokenEncoder
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
return o200kTokenEncoder
}
return defaultTokenEncoder
}
@@ -82,7 +80,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
return len(tokenEncoder.Encode(text, nil, nil))
}
func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
baseTokens := 85
if model == "glm-4v" {
return 1047, nil
@@ -92,11 +90,7 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
}
// TODO: 非流模式下不计算图片token数量
if !constant.GetMediaTokenNotStream && !stream {
return 1000, nil
}
// 是否统计图片token
if !constant.GetMediaToken {
return 1000, nil
return 256, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
@@ -108,6 +102,13 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
tileTokens = 5667
baseTokens = 2833
}
// 是否统计图片token
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
if info.ChannelType == common.ChannelTypeGemini || info.ChannelType == common.ChannelTypeVertexAi || info.ChannelType == common.ChannelTypeAnthropic {
return 3 * baseTokens, nil
}
var config image.Config
var err error
var format string
@@ -157,9 +158,9 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
return tiles*tileTokens + baseTokens, nil
}
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, error) {
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
tkm := 0
msgTokens, err := CountTokenMessages(request.Messages, model, request.Stream)
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
if err != nil {
return 0, err
}
@@ -181,7 +182,7 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens, err := CountTokenInput(countStr, model)
toolTokens, err := CountTokenInput(countStr, request.Model)
if err != nil {
return 0, err
}
@@ -258,7 +259,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
return textToken, audioToken, nil
}
func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
@@ -292,7 +293,7 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
imageTokenNum, err := getImageToken(&imageUrl, model, stream)
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream)
if err != nil {
return 0, err
}

View File

@@ -1,33 +1,47 @@
package common
package setting
import (
"encoding/json"
"errors"
"one-api/common"
)
var GroupRatio = map[string]float64{
var groupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
func GetGroupRatioCopy() map[string]float64 {
groupRatioCopy := make(map[string]float64)
for k, v := range groupRatio {
groupRatioCopy[k] = v
}
return groupRatioCopy
}
func ContainsGroupRatio(name string) bool {
_, ok := groupRatio[name]
return ok
}
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
jsonBytes, err := json.Marshal(groupRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
GroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
groupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &groupRatio)
}
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
ratio, ok := groupRatio[name]
if !ok {
SysError("group ratio not found: " + name)
common.SysError("group ratio not found: " + name)
return 1
}
return ratio

View File

@@ -0,0 +1,52 @@
package setting
import (
"encoding/json"
"one-api/common"
)
var userUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func GetUserUsableGroupsCopy() map[string]string {
copyUserUsableGroups := make(map[string]string)
for k, v := range userUsableGroups {
copyUserUsableGroups[k] = v
}
return copyUserUsableGroups
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(userUsableGroups)
if err != nil {
common.SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
userUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &userUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
groupsCopy := GetUserUsableGroupsCopy()
if userGroup == "" {
if _, ok := groupsCopy["default"]; !ok {
groupsCopy["default"] = "default"
}
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := groupsCopy[userGroup]; !ok {
groupsCopy[userGroup] = "用户分组"
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return groupsCopy
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := userUsableGroups[groupName]
return ok
}

View File

@@ -1,5 +1,5 @@
<!doctype html>
<html lang="en">
<html lang="zh">
<head>
<meta charset="utf-8" />
<link rel="icon" href="/logo.png" />

5127
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -162,9 +162,15 @@ const ChannelsTable = () => {
return (
<div>
<Space spacing={2}>
{text?.split(',').map((item, index) => {
return renderGroup(item);
})}
{text?.split(',')
.sort((a, b) => {
if (a === 'default') return -1;
if (b === 'default') return 1;
return a.localeCompare(b);
})
.map((item, index) => {
return renderGroup(item);
})}
</Space>
</div>
);
@@ -507,6 +513,8 @@ const ChannelsTable = () => {
const [selectedChannels, setSelectedChannels] = useState([]);
const [showEditPriority, setShowEditPriority] = useState(false);
const [enableTagMode, setEnableTagMode] = useState(false);
const [showBatchSetTag, setShowBatchSetTag] = useState(false);
const [batchSetTagValue, setBatchSetTagValue] = useState('');
const removeRecord = (record) => {
@@ -968,6 +976,29 @@ const ChannelsTable = () => {
}
};
const batchSetChannelTag = async () => {
if (selectedChannels.length === 0) {
showError(t('请先选择要设置标签的渠道!'));
return;
}
if (batchSetTagValue === '') {
showError(t('标签不能为空!'));
return;
}
let ids = selectedChannels.map(channel => channel.id);
const res = await API.post('/api/channel/batch/tag', {
ids: ids,
tag: batchSetTagValue === '' ? null : batchSetTagValue
});
if (res.data.success) {
showSuccess(t('已为 ${count} 个渠道设置标签!').replace('${count}', res.data.data));
await refresh();
setShowBatchSetTag(false);
} else {
showError(res.data.message);
}
};
return (
<>
<EditTagModal
@@ -1115,11 +1146,11 @@ const ChannelsTable = () => {
</div>
<div style={{ marginTop: 20 }}>
<Space>
<Typography.Text strong>{t('开启批量删除')}</Typography.Text>
<Typography.Text strong>{t('开启批量操作')}</Typography.Text>
<Switch
label={t('开启批量删除')}
label={t('开启批量操作')}
uncheckedText={t('关')}
aria-label={t('是否开启批量删除')}
aria-label={t('是否开启批量操作')}
onChange={(v) => {
setEnableBatchDelete(v);
}}
@@ -1167,7 +1198,17 @@ const ChannelsTable = () => {
loadChannels(0, pageSize, idSort, v);
}}
/>
<Button
disabled={!enableBatchDelete}
theme="light"
type="primary"
style={{ marginRight: 8 }}
onClick={() => setShowBatchSetTag(true)}
>
{t('批量设置标签')}
</Button>
</Space>
</div>
@@ -1201,6 +1242,23 @@ const ChannelsTable = () => {
: null
}
/>
<Modal
title={t('批量设置标签')}
visible={showBatchSetTag}
onOk={batchSetChannelTag}
onCancel={() => setShowBatchSetTag(false)}
maskClosable={false}
centered={true}
>
<div style={{ marginBottom: 20 }}>
<Typography.Text>{t('请输入要设置的标签名称')}</Typography.Text>
</div>
<Input
placeholder={t('请输入标签名称')}
value={batchSetTagValue}
onChange={(v) => setBatchSetTagValue(v)}
/>
</Modal>
</>
);
};

View File

@@ -199,7 +199,7 @@ const HeaderBar = () => {
</Dropdown.Menu>
}
>
<Nav.Item itemKey={'new-year'} text={'🏮'} />
<Nav.Item itemKey={'new-year'} text={'🎉'} />
</Dropdown>
)}
{/* <Nav.Item itemKey={'about'} icon={<IconHelpCircle />} /> */}

View File

@@ -185,7 +185,10 @@ const LogsTable = () => {
size='small'
color={stringToColor(text)}
style={{ marginRight: 4 }}
onClick={() => showUserInfo(record.user_id)}
onClick={(event) => {
event.stopPropagation();
showUserInfo(record.user_id)
}}
>
{typeof text === 'string' && text.slice(0, 1)}
</Avatar>
@@ -205,8 +208,9 @@ const LogsTable = () => {
<Tag
color='grey'
size='large'
onClick={() => {
copyText(text);
onClick={(event) => {
//cancel the row click event
copyText(event, text);
}}
>
{' '}
@@ -265,8 +269,8 @@ const LogsTable = () => {
<Tag
color={stringToColor(text)}
size='large'
onClick={() => {
copyText(text);
onClick={(event) => {
copyText(event, text);
}}
>
{' '}
@@ -445,17 +449,11 @@ const LogsTable = () => {
});
const handleInputChange = (value, name) => {
if (value && (name === 'start_timestamp' || name === 'end_timestamp')) {
// 确保日期值是有效的
const dateValue = typeof value === 'string' ? value : timestamp2string(value);
setInputs(inputs => ({ ...inputs, [name]: dateValue }));
} else {
setInputs(inputs => ({ ...inputs, [name]: value }));
}
setInputs(inputs => ({ ...inputs, [name]: value }));
};
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(3) / 1000;
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
url = encodeURI(url);
@@ -524,7 +522,7 @@ const LogsTable = () => {
let expandDatesLocal = {};
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = i;
logs[i].key = logs[i].id;
let other = getLogOther(logs[i].other);
let expandDataLocal = [];
if (isAdmin()) {
@@ -656,11 +654,12 @@ const LogsTable = () => {
await loadLogs(activePage, pageSize, logType);
};
const copyText = async (text) => {
const copyText = async (e, text) => {
e.stopPropagation();
if (await copy(text)) {
showSuccess('已复制:' + text);
} else {
Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text });
Modal.error({ title: t('无法复制到剪贴板,请手动复制'), content: text });
}
};
@@ -826,6 +825,12 @@ const LogsTable = () => {
dataSource={logs}
rowKey="key"
pagination={{
formatPageText: (page) =>
t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
start: page.currentStart,
end: page.currentEnd,
total: logs.length
}),
currentPage: activePage,
pageSize: pageSize,
total: logCount,

View File

@@ -81,41 +81,24 @@ const ModelPricing = () => {
}
function renderAvailable(available) {
return available ? (
return (
<Popover
content={
<div style={{ padding: 8 }}>{t('您的分组可以使用该模型')}</div>
}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
content={
<div style={{ padding: 8 }}>{t('您的分组可以使用该模型')}</div>
}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
>
<IconVerify style={{ color: 'green' }} size="large" />
<IconVerify style={{ color: 'green' }} size="large" />
</Popover>
) : (
<Popover
content={
<div style={{ padding: 8 }}>{t('您的分组无权使用该模型')}</div>
}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
>
<IconUploadError style={{ color: '#FFA54F' }} size="large" />
</Popover>
);
)
}
const columns = [
@@ -162,36 +145,39 @@ const ModelPricing = () => {
title: t('可用分组'),
dataIndex: 'enable_groups',
render: (text, record, index) => {
// enable_groups is a string array
return (
<Space>
{text.map((group) => {
if (group === selectedGroup) {
return (
<Tag
color='blue'
size='large'
prefixIcon={<IconVerify />}
>
{group}
</Tag>
);
} else {
return (
<Tag
color='blue'
size='large'
onClick={() => {
setSelectedGroup(group);
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
group: group,
ratio: groupRatio[group]
}));
}}
>
{group}
</Tag>
);
if (usableGroup[group]) {
if (group === selectedGroup) {
return (
<Tag
color='blue'
size='large'
prefixIcon={<IconVerify />}
>
{group}
</Tag>
);
} else {
return (
<Tag
color='blue'
size='large'
onClick={() => {
setSelectedGroup(group);
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
group: group,
ratio: groupRatio[group]
}));
}}
>
{group}
</Tag>
);
}
}
})}
</Space>
@@ -275,6 +261,7 @@ const ModelPricing = () => {
const [loading, setLoading] = useState(true);
const [userState, userDispatch] = useContext(UserContext);
const [groupRatio, setGroupRatio] = useState({});
const [usableGroup, setUsableGroup] = useState({});
const setModelsFormat = (models, groupRatio) => {
for (let i = 0; i < models.length; i++) {
@@ -309,9 +296,10 @@ const ModelPricing = () => {
let url = '';
url = `/api/pricing`;
const res = await API.get(url);
const { success, message, data, group_ratio } = res.data;
const { success, message, data, group_ratio, usable_group } = res.data;
if (success) {
setGroupRatio(group_ratio);
setUsableGroup(usable_group);
setSelectedGroup(userState.user ? userState.user.group : 'default')
setModelsFormat(data, group_ratio);
} else {

View File

@@ -146,8 +146,9 @@ const PersonalSetting = () => {
let res = await API.get(`/api/user/models`);
const {success, message, data} = res.data;
if (success) {
setModels(data);
console.log(data);
if (data != null) {
setModels(data);
}
} else {
showError(message);
}

View File

@@ -178,6 +178,7 @@ const RedemptionsTable = () => {
const [searching, setSearching] = useState(false);
const [tokenCount, setTokenCount] = useState(ITEMS_PER_PAGE);
const [selectedKeys, setSelectedKeys] = useState([]);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [editingRedemption, setEditingRedemption] = useState({
id: undefined,
});
@@ -187,40 +188,20 @@ const RedemptionsTable = () => {
setShowEdit(false);
};
// const setCount = (data) => {
// if (data.length >= (activePage) * ITEMS_PER_PAGE) {
// setTokenCount(data.length + 1);
// } else {
// setTokenCount(data.length);
// }
// }
const setRedemptionFormat = (redeptions) => {
// for (let i = 0; i < redeptions.length; i++) {
// redeptions[i].key = '' + redeptions[i].id;
// }
// data.key = '' + data.id
setRedemptions(redeptions);
if (redeptions.length >= activePage * ITEMS_PER_PAGE) {
setTokenCount(redeptions.length + 1);
} else {
setTokenCount(redeptions.length);
}
};
const loadRedemptions = async (startIdx) => {
const res = await API.get(`/api/redemption/?p=${startIdx}`);
const loadRedemptions = async (startIdx, pageSize) => {
const res = await API.get(`/api/redemption/?p=${startIdx}&page_size=${pageSize}`);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setRedemptionFormat(data);
} else {
let newRedemptions = redemptions;
newRedemptions.push(...data);
setRedemptionFormat(newRedemptions);
}
const newPageData = data.items;
setActivePage(data.page);
setTokenCount(data.total);
setRedemptionFormat(newPageData);
} else {
showError(message);
showError(message);
}
setLoading(false);
};
@@ -248,16 +229,15 @@ const RedemptionsTable = () => {
const onPaginationChange = (e, { activePage }) => {
(async () => {
if (activePage === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
await loadRedemptions(activePage - 1);
if (activePage === Math.ceil(redemptions.length / pageSize) + 1) {
await loadRedemptions(activePage - 1, pageSize);
}
setActivePage(activePage);
})();
};
useEffect(() => {
loadRedemptions(0)
loadRedemptions(0, pageSize)
.then()
.catch((reason) => {
showError(reason);
@@ -265,7 +245,7 @@ const RedemptionsTable = () => {
}, []);
const refresh = async () => {
await loadRedemptions(activePage - 1);
await loadRedemptions(activePage - 1, pageSize);
};
const manageRedemption = async (id, action, record) => {
@@ -300,23 +280,21 @@ const RedemptionsTable = () => {
}
};
const searchRedemptions = async () => {
const searchRedemptions = async (keyword, page, pageSize) => {
if (searchKeyword === '') {
// if keyword is blank, load files instead.
await loadRedemptions(0);
setActivePage(1);
return;
await loadRedemptions(page, pageSize);
return;
}
setSearching(true);
const res = await API.get(
`/api/redemption/search?keyword=${searchKeyword}`,
);
const res = await API.get(`/api/redemption/search?keyword=${keyword}&p=${page}&page_size=${pageSize}`);
const { success, message, data } = res.data;
if (success) {
setRedemptions(data);
setActivePage(1);
const newPageData = data.items;
setActivePage(data.page);
setTokenCount(data.total);
setRedemptionFormat(newPageData);
} else {
showError(message);
showError(message);
}
setSearching(false);
};
@@ -341,16 +319,14 @@ const RedemptionsTable = () => {
const handlePageChange = (page) => {
setActivePage(page);
if (page === Math.ceil(redemptions.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
loadRedemptions(page - 1).then((r) => {});
if (searchKeyword === '') {
loadRedemptions(page, pageSize).then();
} else {
searchRedemptions(searchKeyword, page, pageSize).then();
}
};
let pageData = redemptions.slice(
(activePage - 1) * ITEMS_PER_PAGE,
activePage * ITEMS_PER_PAGE,
);
let pageData = redemptions;
const rowSelection = {
onSelect: (record, selected) => {},
onSelectAll: (selected, selectedRows) => {},
@@ -379,7 +355,9 @@ const RedemptionsTable = () => {
visiable={showEdit}
handleClose={closeEdit}
></EditRedemption>
<Form onSubmit={searchRedemptions}>
<Form onSubmit={()=> {
searchRedemptions(searchKeyword, activePage, pageSize).then();
}}>
<Form.Input
label={t('搜索关键字')}
field='keyword'
@@ -431,20 +409,25 @@ const RedemptionsTable = () => {
dataSource={pageData}
pagination={{
currentPage: activePage,
pageSize: ITEMS_PER_PAGE,
pageSize: pageSize,
total: tokenCount,
// showSizeChanger: true,
// pageSizeOptions: [10, 20, 50, 100],
showSizeChanger: true,
pageSizeOpts: [10, 20, 50, 100],
formatPageText: (page) =>
t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
start: page.currentStart,
end: page.currentEnd,
total: redemptions.length
total: tokenCount
}),
// onPageSizeChange: (size) => {
// setPageSize(size);
// setActivePage(1);
// },
onPageSizeChange: (size) => {
setPageSize(size);
setActivePage(1);
if (searchKeyword === '') {
loadRedemptions(1, size).then();
} else {
searchRedemptions(searchKeyword, 1, size).then();
}
},
onPageChange: handlePageChange,
}}
loading={loading}

View File

@@ -231,6 +231,7 @@ const UsersTable = () => {
const [users, setUsers] = useState([]);
const [loading, setLoading] = useState(true);
const [activePage, setActivePage] = useState(1);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [searchGroup, setSearchGroup] = useState('');
@@ -242,14 +243,6 @@ const UsersTable = () => {
id: undefined,
});
const setCount = (data) => {
if (data.length >= activePage * ITEMS_PER_PAGE) {
setUserCount(data.length + 1);
} else {
setUserCount(data.length);
}
};
const removeRecord = (key) => {
let newDataSource = [...users];
if (key != null) {
@@ -263,37 +256,30 @@ const UsersTable = () => {
}
};
const loadUsers = async (startIdx) => {
const res = await API.get(`/api/user/?p=${startIdx}`);
const setUserFormat = (users) => {
for (let i = 0; i < users.length; i++) {
users[i].key = users[i].id;
}
setUsers(users);
}
const loadUsers = async (startIdx, pageSize) => {
const res = await API.get(`/api/user/?p=${startIdx}&page_size=${pageSize}`);
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setUsers(data);
setCount(data);
} else {
let newUsers = users;
newUsers.push(...data);
setUsers(newUsers);
setCount(newUsers);
}
const newPageData = data.items;
setActivePage(data.page);
setUserCount(data.total);
setUserFormat(newPageData);
} else {
showError(message);
}
setLoading(false);
};
const onPaginationChange = (e, { activePage }) => {
(async () => {
if (activePage === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
await loadUsers(activePage - 1);
}
setActivePage(activePage);
})();
};
useEffect(() => {
loadUsers(0)
loadUsers(0, pageSize)
.then()
.catch((reason) => {
showError(reason);
@@ -341,21 +327,22 @@ const UsersTable = () => {
}
};
const searchUsers = async (searchKeyword, searchGroup) => {
const searchUsers = async (startIdx, pageSize, searchKeyword, searchGroup) => {
if (searchKeyword === '' && searchGroup === '') {
// if keyword is blank, load files instead.
await loadUsers(0);
setActivePage(1);
return;
// if keyword is blank, load files instead.
await loadUsers(startIdx, pageSize);
return;
}
setSearching(true);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}&group=${searchGroup}`);
const res = await API.get(`/api/user/search?keyword=${searchKeyword}&group=${searchGroup}&p=${startIdx}&page_size=${pageSize}`);
const { success, message, data } = res.data;
if (success) {
setUsers(data);
setActivePage(1);
const newPageData = data.items;
setActivePage(data.page);
setUserCount(data.total);
setUserFormat(newPageData);
} else {
showError(message);
showError(message);
}
setSearching(false);
};
@@ -364,33 +351,15 @@ const UsersTable = () => {
setSearchKeyword(value.trim());
};
const sortUser = (key) => {
if (users.length === 0) return;
setLoading(true);
let sortedUsers = [...users];
sortedUsers.sort((a, b) => {
return ('' + a[key]).localeCompare(b[key]);
});
if (sortedUsers[0].id === users[0].id) {
sortedUsers.reverse();
}
setUsers(sortedUsers);
setLoading(false);
};
const handlePageChange = (page) => {
setActivePage(page);
if (page === Math.ceil(users.length / ITEMS_PER_PAGE) + 1) {
// In this case we have to load more data and then append them.
loadUsers(page - 1).then((r) => {});
if (searchKeyword === '' && searchGroup === '') {
loadUsers(page, pageSize).then();
} else {
searchUsers(page, pageSize, searchKeyword, searchGroup).then();
}
};
const pageData = users.slice(
(activePage - 1) * ITEMS_PER_PAGE,
activePage * ITEMS_PER_PAGE,
);
const closeAddUser = () => {
setShowAddUser(false);
};
@@ -403,10 +372,11 @@ const UsersTable = () => {
};
const refresh = async () => {
setActivePage(1)
if (searchKeyword === '') {
await loadUsers(activePage - 1);
await loadUsers(activePage, pageSize);
} else {
await searchUsers();
await searchUsers(searchKeyword, searchGroup);
}
};
@@ -429,6 +399,17 @@ const UsersTable = () => {
}
};
const handlePageSizeChange = async (size) => {
localStorage.setItem('page-size', size + '');
setPageSize(size);
setActivePage(1);
loadUsers(activePage, size)
.then()
.catch((reason) => {
showError(reason);
});
};
return (
<>
<AddUser
@@ -444,29 +425,32 @@ const UsersTable = () => {
></EditUser>
<Form
onSubmit={() => {
searchUsers(searchKeyword, searchGroup);
searchUsers(activePage, pageSize, searchKeyword, searchGroup);
}}
labelPosition='left'
>
<div style={{ display: 'flex' }}>
<Space>
<Form.Input
label={t('搜索关键字')}
icon='search'
field='keyword'
iconPosition='left'
placeholder={t('搜索用户的 ID用户名显示名称以及邮箱地址 ...')}
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
<Tooltip content={t('支持搜索用户的 ID、用户名、显示名称和邮箱地址')}>
<Form.Input
label={t('搜索关键字')}
icon='search'
field='keyword'
iconPosition='left'
placeholder={t('搜索关键字')}
value={searchKeyword}
loading={searching}
onChange={(value) => handleKeywordChange(value)}
/>
</Tooltip>
<Form.Select
field='group'
label={t('分组')}
optionList={groupOptions}
onChange={(value) => {
setSearchGroup(value);
searchUsers(searchKeyword, value);
searchUsers(activePage, pageSize, searchKeyword, value);
}}
/>
<Button
@@ -492,7 +476,7 @@ const UsersTable = () => {
<Table
columns={columns}
dataSource={pageData}
dataSource={users}
pagination={{
formatPageText: (page) =>
t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
@@ -501,9 +485,13 @@ const UsersTable = () => {
total: users.length
}),
currentPage: activePage,
pageSize: ITEMS_PER_PAGE,
pageSize: pageSize,
total: userCount,
pageSizeOpts: [10, 20, 50, 100],
showSizeChanger: true,
onPageSizeChange: (size) => {
handlePageSizeChange(size);
},
onPageChange: handlePageChange,
}}
loading={loading}

View File

@@ -1,5 +1,6 @@
import i18next from 'i18next';
import { Tag } from '@douyinfe/semi-ui';
import { Modal, Tag, Typography } from '@douyinfe/semi-ui';
import { copy, showSuccess } from './utils.js';
export function renderText(text, limit) {
if (text.length > limit) {
@@ -38,6 +39,14 @@ export function renderGroup(group) {
size='large'
color={tagColors[group] || stringToColor(group)}
key={group}
onClick={async (event) => {
event.stopPropagation();
if (await copy(group)) {
showSuccess(i18next.t('已复制:') + group);
} else {
Modal.error({ title: t('无法复制到剪贴板,请手动复制'), content: group });
}
}}
>
{group}
</Tag>
@@ -46,6 +55,81 @@ export function renderGroup(group) {
);
}
export function renderRatio(ratio) {
let color = 'green';
if (ratio > 5) {
color = 'red';
} else if (ratio > 3) {
color = 'orange';
} else if (ratio > 1) {
color = 'blue';
}
return <Tag color={color}>{ratio}x {i18next.t('倍率')}</Tag>;
}
export const renderGroupOption = (item) => {
const {
disabled,
selected,
label,
value,
focused,
className,
style,
onMouseEnter,
onClick,
empty,
emptyContent,
...rest
} = item;
const baseStyle = {
display: 'flex',
justifyContent: 'space-between',
alignItems: 'center',
padding: '8px 16px',
cursor: disabled ? 'not-allowed' : 'pointer',
backgroundColor: focused ? 'var(--semi-color-fill-0)' : 'transparent',
opacity: disabled ? 0.5 : 1,
...(selected && {
backgroundColor: 'var(--semi-color-primary-light-default)',
}),
'&:hover': {
backgroundColor: !disabled && 'var(--semi-color-fill-1)'
}
};
const handleClick = () => {
if (!disabled && onClick) {
onClick();
}
};
const handleMouseEnter = (e) => {
if (!disabled && onMouseEnter) {
onMouseEnter(e);
}
};
return (
<div
style={baseStyle}
onClick={handleClick}
onMouseEnter={handleMouseEnter}
>
<div style={{ display: 'flex', flexDirection: 'column', gap: '4px' }}>
<Typography.Text strong type={disabled ? 'tertiary' : undefined}>
{value}
</Typography.Text>
<Typography.Text type="secondary" size="small">
{label}
</Typography.Text>
</div>
{item.ratio && renderRatio(item.ratio)}
</div>
);
};
export function renderNumber(num) {
if (num >= 1000000000) {
return (num / 1000000000).toFixed(1) + 'B';
@@ -59,6 +143,9 @@ export function renderNumber(num) {
}
export function renderQuotaNumberWithDigit(num, digits = 2) {
if (typeof num !== 'number' || isNaN(num)) {
return 0;
}
let displayInCurrency = localStorage.getItem('display_in_currency');
num = num.toFixed(digits);
if (displayInCurrency) {
@@ -340,7 +427,7 @@ export const modelColorMap = {
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
'gpt-3.5-turbo-16k': 'rgb(149,252,206)', // 淡橙色
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃<EFBFBD><EFBFBD><EFBFBD>
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
'gpt-4': 'rgb(135,206,235)', // 天蓝色
// 'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
@@ -363,7 +450,7 @@ export const modelColorMap = {
'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红
'text-embedding-v1': 'rgb(255,174,185)', // 浅粉红色(略有区别)
'text-moderation-latest': 'rgb(255,130,171)', // 强粉色
'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(<EFBFBD><EFBFBD><EFBFBD>Babbage相同表示同一类功能
'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色(Babbage相同表示同一类功能
'tts-1': 'rgb(255,140,0)', // 深橙色
'tts-1-1106': 'rgb(255,165,0)', // 橙色
'tts-1-hd': 'rgb(255,215,0)', // 金色

View File

@@ -49,8 +49,18 @@ export async function copy(text) {
try {
await navigator.clipboard.writeText(text);
} catch (e) {
okay = false;
console.error(e);
try {
// 构建input 执行 复制命令
var _input = window.document.createElement("input");
_input.value = text;
window.document.body.appendChild(_input);
_input.select();
window.document.execCommand("Copy");
window.document.body.removeChild(_input);
} catch (e) {
okay = false;
console.error(e);
}
}
return okay;
}
@@ -180,6 +190,9 @@ export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') {
let month = (date.getMonth() + 1).toString();
let day = date.getDate().toString();
let hour = date.getHours().toString();
if (day === '24') {
console.log("timestamp", timestamp);
}
if (month.length === 1) {
month = '0' + month;
}

View File

@@ -546,8 +546,8 @@
"是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection",
"开启批量操作": "Enable batch selection",
"是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",

View File

@@ -430,6 +430,8 @@
"一小时后过期": "Expires after one hour",
"一分钟后过期": "Expires after one minute",
"创建新的令牌": "Create New Token",
"令牌分组,默认为用户的分组": "Token group, default is the your's group",
"IP白名单请勿过度信任此功能": "IP whitelist (do not overly trust this function)",
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
"设为无限额度": "Set to unlimited quota",
"更新令牌信息": "Update Token Information",
@@ -546,8 +548,8 @@
"是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection",
"开启批量操作": "Enable batch selection",
"是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",
@@ -866,8 +868,8 @@
"请选择模式": "Please select mode",
"图片代理方式": "Picture agency method",
"用于替换 https://cdn.discordapp.com 的域名": "The domain name used to replace https://cdn.discordapp.com",
"一个月": "a month",
"一天": "one day",
"一个月": "A month",
"一天": "One day",
"令牌渠道分组选择": "Token channel grouping selection",
"只可使用对应分组包含的模型。": "Only models contained in the corresponding group can be used.",
"渠道分组": "Channel grouping",
@@ -876,7 +878,7 @@
"启用模型限制(非必要,不建议启用)": "Enable model restrictions (not necessary, not recommended)",
"秒": "Second",
"更新令牌后需等待几分钟生效": "It will take a few minutes to take effect after updating the token.",
"一小时": "one hour",
"一小时": "One hour",
"新建数量": "New quantity",
"加载失败,请稍后重试": "Loading failed, please try again later",
"未设置": "Not set",
@@ -1235,5 +1237,10 @@
"更多": "Expand more",
"个模型": "models",
"可用模型": "Available models",
"时间范围": "Time range"
"时间范围": "Time range",
"批量设置标签": "Batch set tag",
"请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name",
"支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address",
"已注销": "Logged out"
}

View File

@@ -143,8 +143,7 @@ const Detail = (props) => {
content: [
{
key: (datum) => datum['Model'],
value: (datum) =>
renderQuotaNumberWithDigit(parseFloat(datum['Usage']), 4),
value: (datum) => renderQuota(datum['rawQuota'] || 0, 4),
},
],
},
@@ -152,22 +151,28 @@ const Detail = (props) => {
content: [
{
key: (datum) => datum['Model'],
value: (datum) => datum['Usage'],
value: (datum) => datum['rawQuota'] || 0,
},
],
updateContent: (array) => {
array.sort((a, b) => b.value - a.value);
let sum = 0;
for (let i = 0; i < array.length; i++) {
sum += parseFloat(array[i].value);
array[i].value = renderQuotaNumberWithDigit(
parseFloat(array[i].value),
4,
);
if (array[i].key == "其他") {
continue;
}
let value = parseFloat(array[i].value);
if (isNaN(value)) {
value = 0;
}
if (array[i].datum && array[i].datum.TimeSum) {
sum = array[i].datum.TimeSum;
}
array[i].value = renderQuota(value, 4);
}
array.unshift({
key: t('总计'),
value: renderQuotaNumberWithDigit(sum, 4),
value: renderQuota(sum, 4),
});
return array;
},
@@ -212,19 +217,8 @@ const Detail = (props) => {
created_at: now.getTime() / 1000,
});
}
// 根据dataExportDefaultTime重制时间粒度
let timeGranularity = 3600;
if (dataExportDefaultTime === 'day') {
timeGranularity = 86400;
} else if (dataExportDefaultTime === 'week') {
timeGranularity = 604800;
}
// sort created_at
data.sort((a, b) => a.created_at - b.created_at);
data.forEach((item) => {
item['created_at'] =
Math.floor(item['created_at'] / timeGranularity) * timeGranularity;
});
updateChartData(data);
} else {
showError(message);
@@ -250,14 +244,14 @@ const Detail = (props) => {
let uniqueModels = new Set();
let totalTokens = 0;
// 收集所有唯一的模型名称和时间点
let uniqueTimes = new Set();
// 收集所有唯一的模型名称
data.forEach(item => {
uniqueModels.add(item.model_name);
uniqueTimes.add(timestamp2string1(item.created_at, dataExportDefaultTime));
totalTokens += item.token_used;
totalQuota += item.quota;
totalTimes += item.count;
});
// 处理颜色映射
const newModelColors = {};
Array.from(uniqueModels).forEach((modelName) => {
@@ -267,56 +261,82 @@ const Detail = (props) => {
});
setModelColors(newModelColors);
// 处理饼图数据
for (let item of data) {
totalQuota += item.quota;
totalTimes += item.count;
let pieItem = newPieData.find((it) => it.type === item.model_name);
if (pieItem) {
pieItem.value += item.count;
} else {
newPieData.push({
type: item.model_name,
value: item.count,
// 按时间和模型聚合数据
let aggregatedData = new Map();
data.forEach(item => {
const timeKey = timestamp2string1(item.created_at, dataExportDefaultTime);
const modelKey = item.model_name;
const key = `${timeKey}-${modelKey}`;
if (!aggregatedData.has(key)) {
aggregatedData.set(key, {
time: timeKey,
model: modelKey,
quota: 0,
count: 0
});
}
const existing = aggregatedData.get(key);
existing.quota += item.quota;
existing.count += item.count;
});
// 处理饼图数据
let modelTotals = new Map();
for (let [_, value] of aggregatedData) {
if (!modelTotals.has(value.model)) {
modelTotals.set(value.model, 0);
}
modelTotals.set(value.model, modelTotals.get(value.model) + value.count);
}
// 处理柱状图数据
let timePoints = Array.from(uniqueTimes);
newPieData = Array.from(modelTotals).map(([model, count]) => ({
type: model,
value: count
}));
// 生成时间点序列
let timePoints = Array.from(new Set([...aggregatedData.values()].map(d => d.time)));
if (timePoints.length < 7) {
// 根据时间粒度生成合适的时间点
const generateTimePoints = () => {
let lastTime = Math.max(...data.map(item => item.created_at));
let points = [];
let interval = dataExportDefaultTime === 'hour' ? 3600
const lastTime = Math.max(...data.map(item => item.created_at));
const interval = dataExportDefaultTime === 'hour' ? 3600
: dataExportDefaultTime === 'day' ? 86400
: 604800;
for (let i = 0; i < 7; i++) {
points.push(timestamp2string1(lastTime - (i * interval), dataExportDefaultTime));
}
return points.reverse();
};
timePoints = generateTimePoints();
timePoints = Array.from({length: 7}, (_, i) =>
timestamp2string1(lastTime - (6-i) * interval, dataExportDefaultTime)
);
}
// 为每个时间点和模型生成数据
// 生成柱状图数据
timePoints.forEach(time => {
Array.from(uniqueModels).forEach(model => {
let existingData = data.find(item =>
timestamp2string1(item.created_at, dataExportDefaultTime) === time &&
item.model_name === model
);
newLineData.push({
// 为每个时间点收集所有模型的数据
let timeData = Array.from(uniqueModels).map(model => {
const key = `${time}-${model}`;
const aggregated = aggregatedData.get(key);
return {
Time: time,
Model: model,
Usage: existingData ? parseFloat(getQuotaWithUnit(existingData.quota)) : 0
});
rawQuota: aggregated?.quota || 0,
Usage: aggregated?.quota ? getQuotaWithUnit(aggregated.quota, 4) : 0
};
});
// 计算该时间点的总计
const timeSum = timeData.reduce((sum, item) => sum + item.rawQuota, 0);
// 按照 rawQuota 从大到小排序
timeData.sort((a, b) => b.rawQuota - a.rawQuota);
// 为每个数据点添加该时间的总计
timeData = timeData.map(item => ({
...item,
TimeSum: timeSum
}));
// 将排序后的数据添加到 newLineData
newLineData.push(...timeData);
});
// 排序

View File

@@ -2,11 +2,12 @@ import React, { useCallback, useContext, useEffect, useState } from 'react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../../context/User/index.js';
import { API, getUserIdFromLocalStorage, showError } from '../../helpers/index.js';
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography, Button } from '@douyinfe/semi-ui';
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography, Button, Highlight } from '@douyinfe/semi-ui';
import { SSE } from 'sse';
import { IconSetting } from '@douyinfe/semi-icons';
import { StyleContext } from '../../context/Style/index.js';
import { useTranslation } from 'react-i18next';
import { renderGroupOption } from '../../helpers/render.js';
const roleInfo = {
user: {
@@ -97,15 +98,17 @@ const Playground = () => {
let res = await API.get(`/api/user/self/groups`);
const { success, message, data } = res.data;
if (success) {
let localGroupOptions = Object.keys(data).map((group) => ({
label: data[group],
let localGroupOptions = Object.entries(data).map(([group, info]) => ({
label: info.desc,
value: group,
ratio: info.ratio
}));
if (localGroupOptions.length === 0) {
localGroupOptions = [{
label: t('用户分组'),
value: '',
ratio: 1
}];
} else {
const localUser = JSON.parse(localStorage.getItem('user'));
@@ -326,12 +329,9 @@ const Playground = () => {
}}
value={inputs.group}
autoComplete='new-password'
optionList={groups.map((group) => ({
...group,
label: styleState.isMobile && group.label.length > 16
? group.label.substring(0, 16) + '...'
: group.label,
}))}
optionList={groups}
renderOptionItem={renderGroupOption}
style={{ width: '100%' }}
/>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>{t('模型')}</Typography.Text>

View File

@@ -7,7 +7,7 @@ import {
showSuccess,
timestamp2string,
} from '../../helpers';
import { renderQuotaWithPrompt } from '../../helpers/render';
import { renderGroupOption, renderQuotaWithPrompt } from '../../helpers/render';
import {
AutoComplete,
Banner,
@@ -23,6 +23,7 @@ import {
} from '@douyinfe/semi-ui';
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import { Divider } from 'semantic-ui-react';
import { useTranslation } from 'react-i18next';
const EditToken = (props) => {
const [isEdit, setIsEdit] = useState(false);
@@ -52,6 +53,7 @@ const EditToken = (props) => {
const [models, setModels] = useState([]);
const [groups, setGroups] = useState([]);
const navigate = useNavigate();
const { t } = useTranslation();
const handleInputChange = (name, value) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
@@ -87,7 +89,7 @@ const EditToken = (props) => {
}));
setModels(localModelOptions);
} else {
showError(message);
showError(t(message));
}
};
@@ -95,15 +97,14 @@ const EditToken = (props) => {
let res = await API.get(`/api/user/self/groups`);
const { success, message, data } = res.data;
if (success) {
// return data is a map, key is group name, value is group description
// label is group description, value is group name
let localGroupOptions = Object.keys(data).map((group) => ({
label: data[group],
value: group,
}));
setGroups(localGroupOptions);
let localGroupOptions = Object.entries(data).map(([group, info]) => ({
label: info.desc,
value: group,
ratio: info.ratio
}));
setGroups(localGroupOptions);
} else {
showError(message);
showError(t(message));
}
};
@@ -176,7 +177,7 @@ const EditToken = (props) => {
if (localInputs.expired_time !== -1) {
let time = Date.parse(localInputs.expired_time);
if (isNaN(time)) {
showError('过期时间格式错误!');
showError(t('过期时间格式错误!'));
setLoading(false);
return;
}
@@ -189,11 +190,11 @@ const EditToken = (props) => {
});
const { success, message } = res.data;
if (success) {
showSuccess('令牌更新成功!');
showSuccess(t('令牌更新成功!'));
props.refresh();
props.handleClose();
} else {
showError(message);
showError(t(message));
}
} else {
// 处理新增多个令牌的情况
@@ -209,7 +210,7 @@ const EditToken = (props) => {
if (localInputs.expired_time !== -1) {
let time = Date.parse(localInputs.expired_time);
if (isNaN(time)) {
showError('过期时间格式错误!');
showError(t('过期时间格式错误!'));
setLoading(false);
break;
}
@@ -222,14 +223,14 @@ const EditToken = (props) => {
if (success) {
successCount++;
} else {
showError(message);
showError(t(message));
break; // 如果创建失败,终止循环
}
}
if (successCount > 0) {
showSuccess(
`${successCount}令牌创建成功,请在列表页面点击复制获取令牌!`,
t('令牌创建成功,请在列表页面点击复制获取令牌!')
);
props.refresh();
props.handleClose();
@@ -245,7 +246,7 @@ const EditToken = (props) => {
<SideSheet
placement={isEdit ? 'right' : 'left'}
title={
<Title level={3}>{isEdit ? '更新令牌信息' : '创建新的令牌'}</Title>
<Title level={3}>{isEdit ? t('更新令牌信息') : t('创建新的令牌')}</Title>
}
headerStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
bodyStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
@@ -254,7 +255,7 @@ const EditToken = (props) => {
<div style={{ display: 'flex', justifyContent: 'flex-end' }}>
<Space>
<Button theme='solid' size={'large'} onClick={submit}>
提交
{t('提交')}
</Button>
<Button
theme='solid'
@@ -262,7 +263,7 @@ const EditToken = (props) => {
type={'tertiary'}
onClick={handleCancel}
>
取消
{t('取消')}
</Button>
</Space>
</div>
@@ -274,9 +275,9 @@ const EditToken = (props) => {
<Spin spinning={loading}>
<Input
style={{ marginTop: 20 }}
label='名称'
label={t('名称')}
name='name'
placeholder={'请输入名称'}
placeholder={t('请输入名称')}
onChange={(value) => handleInputChange('name', value)}
value={name}
autoComplete='new-password'
@@ -284,9 +285,9 @@ const EditToken = (props) => {
/>
<Divider />
<DatePicker
label='过期时间'
label={t('过期时间')}
name='expired_time'
placeholder={'请选择过期时间'}
placeholder={t('请选择过期时间')}
onChange={(value) => handleInputChange('expired_time', value)}
value={expired_time}
autoComplete='new-password'
@@ -300,7 +301,7 @@ const EditToken = (props) => {
setExpiredTime(0, 0, 0, 0);
}}
>
永不过期
{t('永不过期')}
</Button>
<Button
type={'tertiary'}
@@ -308,7 +309,7 @@ const EditToken = (props) => {
setExpiredTime(0, 0, 1, 0);
}}
>
一小时
{t('一小时')}
</Button>
<Button
type={'tertiary'}
@@ -316,7 +317,7 @@ const EditToken = (props) => {
setExpiredTime(1, 0, 0, 0);
}}
>
一个月
{t('一个月')}
</Button>
<Button
type={'tertiary'}
@@ -324,7 +325,7 @@ const EditToken = (props) => {
setExpiredTime(0, 1, 0, 0);
}}
>
一天
{t('一天')}
</Button>
</Space>
</div>
@@ -332,17 +333,15 @@ const EditToken = (props) => {
<Divider />
<Banner
type={'warning'}
description={
'注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。'
}
description={t('注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。')}
></Banner>
<div style={{ marginTop: 20 }}>
<Typography.Text>{`额度${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
<Typography.Text>{`${t('额度')}${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
</div>
<AutoComplete
style={{ marginTop: 8 }}
name='remain_quota'
placeholder={'请输入额度'}
placeholder={t('请输入额度')}
onChange={(value) => handleInputChange('remain_quota', value)}
value={remain_quota}
autoComplete='new-password'
@@ -362,22 +361,22 @@ const EditToken = (props) => {
{!isEdit && (
<>
<div style={{ marginTop: 20 }}>
<Typography.Text>新建数量</Typography.Text>
<Typography.Text>{t('新建数量')}</Typography.Text>
</div>
<AutoComplete
style={{ marginTop: 8 }}
label='数量'
placeholder={'请选择或输入创建令牌的数量'}
label={t('数量')}
placeholder={t('请选择或输入创建令牌的数量')}
onChange={(value) => handleTokenCountChange(value)}
onSelect={(value) => handleTokenCountChange(value)}
value={tokenCount.toString()}
autoComplete='off'
type='number'
data={[
{ value: 10, label: '10个' },
{ value: 20, label: '20个' },
{ value: 30, label: '30个' },
{ value: 100, label: '100个' },
{ value: 10, label: t('10个') },
{ value: 20, label: t('20个') },
{ value: 30, label: t('30个') },
{ value: 100, label: t('100个') },
]}
disabled={unlimited_quota}
/>
@@ -392,17 +391,17 @@ const EditToken = (props) => {
setUnlimitedQuota();
}}
>
{unlimited_quota ? '取消无限额度' : '设为无限额度'}
{unlimited_quota ? t('取消无限额度') : t('设为无限额度')}
</Button>
</div>
<Divider />
<div style={{ marginTop: 10 }}>
<Typography.Text>IP白名单请勿过度信任此功能</Typography.Text>
<Typography.Text>{t('IP白名单请勿过度信任此功能)')}</Typography.Text>
</div>
<TextArea
label='IP白名单'
label={t('IP白名单')}
name='allow_ips'
placeholder={'允许的IP一行一个'}
placeholder={t('允许的IP一行一个')}
onChange={(value) => {
handleInputChange('allow_ips', value);
}}
@@ -417,16 +416,15 @@ const EditToken = (props) => {
onChange={(e) =>
handleInputChange('model_limits_enabled', e.target.checked)
}
></Checkbox>
<Typography.Text>
启用模型限制非必要不建议启用
</Typography.Text>
>
{t('启用模型限制(非必要,不建议启用)')}
</Checkbox>
</Space>
</div>
<Select
style={{ marginTop: 8 }}
placeholder={'请选择该渠道所支持的模型'}
placeholder={t('请选择该渠道所支持的模型')}
name='models'
required
multiple
@@ -440,25 +438,27 @@ const EditToken = (props) => {
disabled={!model_limits_enabled}
/>
<div style={{ marginTop: 10 }}>
<Typography.Text>令牌分组默认为用户的分组</Typography.Text>
<Typography.Text>{t('令牌分组,默认为用户的分组')}</Typography.Text>
</div>
{groups.length > 0 ?
<Select
style={{ marginTop: 8 }}
placeholder={'令牌分组,默认为用户的分组'}
placeholder={t('令牌分组,默认为用户的分组')}
name='gruop'
required
selection
onChange={(value) => {
handleInputChange('group', value);
}}
position={'topLeft'}
renderOptionItem={renderGroupOption}
value={inputs.group}
autoComplete='new-password'
optionList={groups}
/>:
<Select
style={{ marginTop: 8 }}
placeholder={'管理员未设置用户可选分组'}
placeholder={t('管理员未设置用户可选分组')}
name='gruop'
disabled={true}
/>

View File

@@ -19,7 +19,11 @@ const AddUser = (props) => {
const submit = async () => {
setLoading(true);
if (inputs.username === '' || inputs.password === '') return;
if (inputs.username === '' || inputs.password === '') {
setLoading(false);
showError('用户名和密码不能为空!');
return;
}
const res = await API.post(`/api/user/`, inputs);
const { success, message } = res.data;
if (success) {