mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-12 18:57:28 +00:00
Compare commits
148 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6a5a839d4d | ||
|
|
5d8a0952b4 | ||
|
|
bd08ecc1e0 | ||
|
|
e4f61c1084 | ||
|
|
a38215478f | ||
|
|
f978d8224e | ||
|
|
ab59887933 | ||
|
|
2b7dff2d94 | ||
|
|
58752d2dcf | ||
|
|
8e9dae7b5f | ||
|
|
02571c20ff | ||
|
|
8201daa4b4 | ||
|
|
5b54624cd5 | ||
|
|
db737567fb | ||
|
|
5c3898d13e | ||
|
|
2fdb2be6d0 | ||
|
|
ab78efc815 | ||
|
|
fcf97d1796 | ||
|
|
e85cc6acbe | ||
|
|
da002e6ca9 | ||
|
|
070e7b6911 | ||
|
|
d5a3eb7d04 | ||
|
|
a87d4271d3 | ||
|
|
b2badad554 | ||
|
|
133d8c9f77 | ||
|
|
9708d645d3 | ||
|
|
0bca4d3efc | ||
|
|
7572e791f6 | ||
|
|
a180d13182 | ||
|
|
edcdb378fd | ||
|
|
4447e51588 | ||
|
|
ba6b0637cc | ||
|
|
3502730dfc | ||
|
|
3746482e8c | ||
|
|
5ed4b60b8f | ||
|
|
547da2da60 | ||
|
|
f88ed4dd5c | ||
|
|
87fc681df3 | ||
|
|
0b9b21eafd | ||
|
|
21f43b0dd8 | ||
|
|
3a7ba5725c | ||
|
|
2e4fa32d63 | ||
|
|
0199896d9a | ||
|
|
edd9049100 | ||
|
|
290c763901 | ||
|
|
226446a3b5 | ||
|
|
ab627db4be | ||
|
|
0f35d2368f | ||
|
|
3c276d13c4 | ||
|
|
b7c3328d43 | ||
|
|
4d8e63bd1a | ||
|
|
51757b83e1 | ||
|
|
87c260093a | ||
|
|
691a878aa2 | ||
|
|
b33d808bc1 | ||
|
|
4559f5b2d3 | ||
|
|
0b9c6ecb00 | ||
|
|
a7d87475af | ||
|
|
ba37750943 | ||
|
|
4fc85d27e9 | ||
|
|
246ca40aac | ||
|
|
59a6fa7274 | ||
|
|
7fa21ce95f | ||
|
|
6b7295bbdf | ||
|
|
b4b6bd46fe | ||
|
|
d5c96cb036 | ||
|
|
1294d286ee | ||
|
|
dc95d0d1e6 | ||
|
|
467439090d | ||
|
|
b77574dad5 | ||
|
|
3ac02879de | ||
|
|
a9160804a3 | ||
|
|
c48a398737 | ||
|
|
e735377218 | ||
|
|
d2b47969da | ||
|
|
af50660887 | ||
|
|
5adf1e272d | ||
|
|
abfb3f4006 | ||
|
|
5f05803643 | ||
|
|
ab0ba9f38c | ||
|
|
e1a93a1b82 | ||
|
|
e6e5f31921 | ||
|
|
8978dc7a8b | ||
|
|
d57e6425e5 | ||
|
|
b9b4b24961 | ||
|
|
4c05377c87 | ||
|
|
a9cdbce9de | ||
|
|
66403275b7 | ||
|
|
c554015526 | ||
|
|
35313ae0d6 | ||
|
|
6c359839cc | ||
|
|
be7e09b14d | ||
|
|
60b624a329 | ||
|
|
47531a6b93 | ||
|
|
0e05f725a4 | ||
|
|
034cc7f118 | ||
|
|
927cd07a3f | ||
|
|
070eba4b4c | ||
|
|
af9cc5ce11 | ||
|
|
f844772126 | ||
|
|
a8a2141626 | ||
|
|
0401f1e9ec | ||
|
|
358af20ad1 | ||
|
|
e455f06851 | ||
|
|
f191f981c4 | ||
|
|
9b659ed4f1 | ||
|
|
d39b52272e | ||
|
|
a0ae6644ee | ||
|
|
1a7da8397b | ||
|
|
dcefd7dfb4 | ||
|
|
21edb75081 | ||
|
|
a28ab3628a | ||
|
|
856465ae59 | ||
|
|
3123d4bb9b | ||
|
|
dd21183261 | ||
|
|
ef4b0bc371 | ||
|
|
3d6859b865 | ||
|
|
0389e76af5 | ||
|
|
a1163dd735 | ||
|
|
a9a284a595 | ||
|
|
95bac28232 | ||
|
|
5bf5419633 | ||
|
|
48817648c3 | ||
|
|
4baaf456a7 | ||
|
|
52356a1b92 | ||
|
|
bdb7c9cbd7 | ||
|
|
a7b17eb1ba | ||
|
|
8ed68e4b12 | ||
|
|
f124404f07 | ||
|
|
3f89ee66e1 | ||
|
|
7c0302b5f8 | ||
|
|
26b70d6a25 | ||
|
|
2509f644bc | ||
|
|
896e1d978f | ||
|
|
6c4f64c397 | ||
|
|
d1f493bf17 | ||
|
|
56188c33b5 | ||
|
|
d9461a477d | ||
|
|
07b47fbf3a | ||
|
|
66d3206d7d | ||
|
|
136a46218b | ||
|
|
3f67db1028 | ||
|
|
936e593a4f | ||
|
|
9ff33405ec | ||
|
|
f25b084d40 | ||
|
|
fe00434454 | ||
|
|
f2957ee558 | ||
|
|
b605ff9b02 |
@@ -27,6 +27,9 @@
|
||||
<a href="https://goreportcard.com/report/github.com/Calcium-Ion/new-api">
|
||||
<img src="https://goreportcard.com/badge/github.com/Calcium-Ion/new-api" alt="GoReportCard">
|
||||
</a>
|
||||
<a href="https://coderabbit.ai">
|
||||
<img src="https://img.shields.io/coderabbit/prs/github/QuantumNous/new-api?utm_source=oss&utm_medium=github&utm_campaign=QuantumNous%2Fnew-api&labelColor=171717&color=FF570A&link=https%3A%2F%2Fcoderabbit.ai&label=CodeRabbit+Reviews" alt="CodeRabbit Pull Request Reviews">
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -180,7 +183,6 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
|
||||
其他基于New API的项目:
|
||||
- [new-api-horizon](https://github.com/Calcium-Ion/new-api-horizon):New API高性能优化版
|
||||
- [VoAPI](https://github.com/VoAPI/VoAPI):基于New API的前端美化版本
|
||||
|
||||
## 帮助支持
|
||||
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
package common
|
||||
|
||||
const (
|
||||
DatabaseTypeMySQL = "mysql"
|
||||
DatabaseTypeSQLite = "sqlite"
|
||||
DatabaseTypePostgreSQL = "postgres"
|
||||
)
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
|
||||
@@ -141,7 +141,11 @@ func RedisHSetObj(key string, obj interface{}, expiration time.Duration) error {
|
||||
|
||||
txn := RDB.TxPipeline()
|
||||
txn.HSet(ctx, key, data)
|
||||
txn.Expire(ctx, key, expiration)
|
||||
|
||||
// 只有在 expiration 大于 0 时才设置过期时间
|
||||
if expiration > 0 {
|
||||
txn.Expire(ctx, key, expiration)
|
||||
}
|
||||
|
||||
_, err := txn.Exec(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -249,13 +249,38 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
|
||||
}
|
||||
|
||||
// GetAudioDuration returns the duration of an audio file in seconds.
|
||||
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
|
||||
func GetAudioDuration(ctx context.Context, filename string, ext string) (float64, error) {
|
||||
// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
|
||||
c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||
}
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
|
||||
return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package constant
|
||||
|
||||
import "one-api/common"
|
||||
|
||||
var (
|
||||
TokenCacheSeconds = common.SyncFrequency
|
||||
UserId2GroupCacheSeconds = common.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = common.SyncFrequency
|
||||
UserId2StatusCacheSeconds = common.SyncFrequency
|
||||
)
|
||||
// 使用函数来避免初始化顺序带来的赋值问题
|
||||
func RedisKeyCacheSeconds() int {
|
||||
return common.SyncFrequency
|
||||
}
|
||||
|
||||
// Cache keys
|
||||
const (
|
||||
|
||||
@@ -7,6 +7,7 @@ var (
|
||||
UserSettingWebhookSecret = "webhook_secret" // WebhookSecret webhook密钥
|
||||
UserSettingNotificationEmail = "notification_email" // NotificationEmail 通知邮箱地址
|
||||
UserAcceptUnsetRatioModel = "accept_unset_model_ratio_model" // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
UserSettingRecordIpLog = "record_ip_log" // 是否记录请求和错误日志IP
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -165,8 +165,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice)
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
@@ -271,6 +271,13 @@ func testAllChannels(notify bool) error {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
}
|
||||
gopool.Go(func() {
|
||||
// 使用 defer 确保无论如何都会重置运行状态,防止死锁
|
||||
defer func() {
|
||||
testAllChannelsLock.Lock()
|
||||
testAllChannelsRunning = false
|
||||
testAllChannelsLock.Unlock()
|
||||
}()
|
||||
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
@@ -305,9 +312,7 @@ func testAllChannels(notify bool) error {
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
testAllChannelsLock.Lock()
|
||||
testAllChannelsRunning = false
|
||||
testAllChannelsLock.Unlock()
|
||||
|
||||
if notify {
|
||||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||
}
|
||||
|
||||
@@ -43,22 +43,31 @@ type OpenAIModelsResponse struct {
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if pageSize < 0 {
|
||||
if pageSize < 1 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
// type filter
|
||||
typeStr := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeStr != "" {
|
||||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
tags, err := model.GetPaginatedTags(p*pageSize, pageSize)
|
||||
// tag 分页:先分页 tag,再取各 tag 下 channels
|
||||
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
@@ -69,21 +78,39 @@ func GetAllChannels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channels, err := model.GetAllChannels(p*pageSize, pageSize, false, idSort)
|
||||
// 计算 tag 总数用于分页
|
||||
total, _ = model.CountAllTags()
|
||||
} else if typeFilter >= 0 {
|
||||
channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
total, _ = model.CountChannelsByType(typeFilter)
|
||||
} else {
|
||||
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
total, _ = model.CountAllChannels()
|
||||
}
|
||||
|
||||
// calculate type counts
|
||||
typeCounts, _ := model.CountChannelsGroupByType()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -210,10 +237,20 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
|
||||
// calculate type counts for search results
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, channel := range channelData {
|
||||
typeCounts[int64(channel.Type)]++
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
103
controller/console_migrate.go
Normal file
103
controller/console_migrate.go
Normal file
@@ -0,0 +1,103 @@
|
||||
// 用于迁移检测的旧键,该文件下个版本会删除
|
||||
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// MigrateConsoleSetting 迁移旧的控制台相关配置到 console_setting.*
|
||||
func MigrateConsoleSetting(c *gin.Context) {
|
||||
// 读取全部 option
|
||||
opts, err := model.AllOption()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// 建立 map
|
||||
valMap := map[string]string{}
|
||||
for _, o := range opts {
|
||||
valMap[o.Key] = o.Value
|
||||
}
|
||||
|
||||
// 处理 APIInfo
|
||||
if v := valMap["ApiInfo"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
if len(arr) > 50 {
|
||||
arr = arr[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(arr)
|
||||
model.UpdateOption("console_setting.api_info", string(bytes))
|
||||
}
|
||||
model.UpdateOption("ApiInfo", "")
|
||||
}
|
||||
// Announcements 直接搬
|
||||
if v := valMap["Announcements"]; v != "" {
|
||||
model.UpdateOption("console_setting.announcements", v)
|
||||
model.UpdateOption("Announcements", "")
|
||||
}
|
||||
// FAQ 转换
|
||||
if v := valMap["FAQ"]; v != "" {
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(v), &arr); err == nil {
|
||||
out := []map[string]interface{}{}
|
||||
for _, item := range arr {
|
||||
q, _ := item["question"].(string)
|
||||
if q == "" {
|
||||
q, _ = item["title"].(string)
|
||||
}
|
||||
a, _ := item["answer"].(string)
|
||||
if a == "" {
|
||||
a, _ = item["content"].(string)
|
||||
}
|
||||
if q != "" && a != "" {
|
||||
out = append(out, map[string]interface{}{"question": q, "answer": a})
|
||||
}
|
||||
}
|
||||
if len(out) > 50 {
|
||||
out = out[:50]
|
||||
}
|
||||
bytes, _ := json.Marshal(out)
|
||||
model.UpdateOption("console_setting.faq", string(bytes))
|
||||
}
|
||||
model.UpdateOption("FAQ", "")
|
||||
}
|
||||
// Uptime Kuma 迁移到新的 groups 结构(console_setting.uptime_kuma_groups)
|
||||
url := valMap["UptimeKumaUrl"]
|
||||
slug := valMap["UptimeKumaSlug"]
|
||||
if url != "" && slug != "" {
|
||||
// 仅当同时存在 URL 与 Slug 时才进行迁移
|
||||
groups := []map[string]interface{}{
|
||||
{
|
||||
"id": 1,
|
||||
"categoryName": "old",
|
||||
"url": url,
|
||||
"slug": slug,
|
||||
"description": "",
|
||||
},
|
||||
}
|
||||
bytes, _ := json.Marshal(groups)
|
||||
model.UpdateOption("console_setting.uptime_kuma_groups", string(bytes))
|
||||
}
|
||||
// 清空旧键内容
|
||||
if url != "" {
|
||||
model.UpdateOption("UptimeKumaUrl", "")
|
||||
}
|
||||
if slug != "" {
|
||||
model.UpdateOption("UptimeKumaSlug", "")
|
||||
}
|
||||
|
||||
// 删除旧键记录
|
||||
oldKeys := []string{"ApiInfo", "Announcements", "FAQ", "UptimeKumaUrl", "UptimeKumaSlug"}
|
||||
model.DB.Where("key IN ?", oldKeys).Delete(&model.Option{})
|
||||
|
||||
// 重新加载 OptionMap
|
||||
model.InitOptionMap()
|
||||
common.SysLog("console setting migrated")
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "migrated"})
|
||||
}
|
||||
@@ -1,15 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if setting.GroupInUserUsableGroups("auto") {
|
||||
usableGroups["auto"] = map[string]interface{}{
|
||||
"ratio": "自动",
|
||||
"desc": setting.GetUsableGroupDescription("auto"),
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
@@ -215,8 +214,12 @@ func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask dto.MidjourneyDto)
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
// 解析其他查询参数
|
||||
@@ -227,31 +230,38 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
items := model.GetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.CountAllTasks(queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
log.Printf("userId = %d \n", userId)
|
||||
|
||||
queryParams := model.TaskQueryParams{
|
||||
MjID: c.Query("mj_id"),
|
||||
@@ -259,19 +269,23 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
EndTimestamp: c.Query("end_timestamp"),
|
||||
}
|
||||
|
||||
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
items := model.GetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.CountAllUserTask(userId, queryParams)
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range logs {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
@@ -24,58 +26,85 @@ func TestStatus(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// 获取HTTP统计信息
|
||||
httpStats := middleware.GetStats()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Server is running",
|
||||
"success": true,
|
||||
"message": "Server is running",
|
||||
"http_stats": httpStats,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetStatus(c *gin.Context) {
|
||||
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
}
|
||||
|
||||
// 根据启用状态注入可选内容
|
||||
if cs.ApiInfoEnabled {
|
||||
data["api_info"] = console_setting.GetApiInfo()
|
||||
}
|
||||
if cs.AnnouncementsEnabled {
|
||||
data["announcements"] = console_setting.GetAnnouncements()
|
||||
}
|
||||
if cs.FAQEnabled {
|
||||
data["faq"] = console_setting.GetFAQ()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
"setup": constant.Setup,
|
||||
"api_info": setting.GetApiInfo(),
|
||||
},
|
||||
"data": data,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -15,6 +14,9 @@ import (
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -179,7 +181,19 @@ func ListModels(c *gin.Context) {
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
models := model.GetGroupModels(group)
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupModels(group)
|
||||
}
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
@@ -102,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -119,8 +121,35 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "ApiInfo":
|
||||
err = setting.ValidateApiInfo(option.Value)
|
||||
case "console_setting.api_info":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.announcements":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.faq":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
case "console_setting.uptime_kuma_groups":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -3,7 +3,6 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -13,6 +12,8 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
@@ -57,9 +58,9 @@ func Playground(c *gin.Context) {
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetPricing(c *gin.Context) {
|
||||
@@ -12,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
@@ -20,12 +21,18 @@ func GetPricing(c *gin.Context) {
|
||||
user, err := model.GetUserCache(userId.(int))
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
for g := range groupRatio {
|
||||
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||
if ok {
|
||||
groupRatio[g] = ratio
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
@@ -40,7 +47,7 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
||||
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -49,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -126,6 +127,10 @@ func AddRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
var keys []string
|
||||
for i := 0; i < redemption.Count; i++ {
|
||||
key := common.GetUUID()
|
||||
@@ -135,6 +140,7 @@ func AddRedemption(c *gin.Context) {
|
||||
Key: key,
|
||||
CreatedTime: common.GetTimestamp(),
|
||||
Quota: redemption.Quota,
|
||||
ExpiredTime: redemption.ExpiredTime,
|
||||
}
|
||||
err = cleanRedemption.Insert()
|
||||
if err != nil {
|
||||
@@ -191,12 +197,18 @@ func UpdateRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
} else {
|
||||
if statusOnly == "" {
|
||||
if err := validateExpiredTime(redemption.ExpiredTime); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
// If you add more fields, please also update redemption.Update()
|
||||
cleanRedemption.Name = redemption.Name
|
||||
cleanRedemption.Quota = redemption.Quota
|
||||
cleanRedemption.ExpiredTime = redemption.ExpiredTime
|
||||
}
|
||||
if statusOnly != "" {
|
||||
cleanRedemption.Status = redemption.Status
|
||||
}
|
||||
err = cleanRedemption.Update()
|
||||
if err != nil {
|
||||
@@ -213,3 +225,27 @@ func UpdateRedemption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteInvalidRedemption(c *gin.Context) {
|
||||
rows, err := model.DeleteInvalidRedemptions()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func validateExpiredTime(expired int64) error {
|
||||
if expired != 0 && expired < common.GetTimestamp() {
|
||||
return errors.New("过期时间不能早于当前时间")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
break
|
||||
|
||||
@@ -75,6 +75,14 @@ func PostSetup(c *gin.Context) {
|
||||
|
||||
// If root doesn't exist, validate and create admin account
|
||||
if !rootExists {
|
||||
// Validate username length: max 12 characters to align with model.User validation
|
||||
if len(req.Username) > 12 {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名长度不能超过12个字符",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Validate password
|
||||
if req.Password != req.ConfirmPassword {
|
||||
c.JSON(400, gin.H{
|
||||
|
||||
@@ -224,9 +224,14 @@ func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool
|
||||
|
||||
func GetAllTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
// 解析其他查询参数
|
||||
@@ -237,24 +242,32 @@ func GetAllTask(c *gin.Context) {
|
||||
Action: c.Query("action"),
|
||||
StartTimestamp: startTimestamp,
|
||||
EndTimestamp: endTimestamp,
|
||||
ChannelID: c.Query("channel_id"),
|
||||
}
|
||||
|
||||
logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Task, 0)
|
||||
}
|
||||
items := model.TaskGetAllTasks((p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.TaskCountAllTasks(queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserTask(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
if pageSize <= 0 {
|
||||
pageSize = common.ItemsPerPage
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
@@ -271,14 +284,17 @@ func GetUserTask(c *gin.Context) {
|
||||
EndTimestamp: endTimestamp,
|
||||
}
|
||||
|
||||
logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Task, 0)
|
||||
}
|
||||
items := model.TaskGetAllUserTask(userId, (p-1)*pageSize, pageSize, queryParams)
|
||||
total := model.TaskCountAllUserTask(userId, queryParams)
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
"data": gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,15 +12,15 @@ func GetAllTokens(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
size, _ := strconv.Atoi(c.Query("size"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
if p < 1 {
|
||||
p = 1
|
||||
}
|
||||
if size <= 0 {
|
||||
size = common.ItemsPerPage
|
||||
} else if size > 100 {
|
||||
size = 100
|
||||
}
|
||||
tokens, err := model.GetAllUserTokens(userId, p*size, size)
|
||||
tokens, err := model.GetAllUserTokens(userId, (p-1)*size, size)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -28,10 +28,18 @@ func GetAllTokens(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Get total count for pagination
|
||||
total, _ := model.CountUserTokens(userId)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": tokens,
|
||||
"data": gin.H{
|
||||
"items": tokens,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": size,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func RequestEpay(c *gin.Context) {
|
||||
payType = "wxpay"
|
||||
}
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
|
||||
154
controller/uptime_kuma.go
Normal file
154
controller/uptime_kuma.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/setting/console_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
requestTimeout = 30 * time.Second
|
||||
httpTimeout = 10 * time.Second
|
||||
uptimeKeySuffix = "_24"
|
||||
apiStatusPath = "/api/status-page/"
|
||||
apiHeartbeatPath = "/api/status-page/heartbeat/"
|
||||
)
|
||||
|
||||
type Monitor struct {
|
||||
Name string `json:"name"`
|
||||
Uptime float64 `json:"uptime"`
|
||||
Status int `json:"status"`
|
||||
Group string `json:"group,omitempty"`
|
||||
}
|
||||
|
||||
type UptimeGroupResult struct {
|
||||
CategoryName string `json:"categoryName"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
}
|
||||
|
||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return errors.New("non-200 status")
|
||||
}
|
||||
|
||||
return json.NewDecoder(resp.Body).Decode(dest)
|
||||
}
|
||||
|
||||
func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[string]interface{}) UptimeGroupResult {
|
||||
url, _ := groupConfig["url"].(string)
|
||||
slug, _ := groupConfig["slug"].(string)
|
||||
categoryName, _ := groupConfig["categoryName"].(string)
|
||||
|
||||
result := UptimeGroupResult{
|
||||
CategoryName: categoryName,
|
||||
Monitors: []Monitor{},
|
||||
}
|
||||
|
||||
if url == "" || slug == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(url, "/")
|
||||
|
||||
var statusData struct {
|
||||
PublicGroupList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MonitorList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"monitorList"`
|
||||
} `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
var heartbeatData struct {
|
||||
HeartbeatList map[string][]struct {
|
||||
Status int `json:"status"`
|
||||
} `json:"heartbeatList"`
|
||||
UptimeList map[string]float64 `json:"uptimeList"`
|
||||
}
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
})
|
||||
|
||||
if g.Wait() != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, pg := range statusData.PublicGroupList {
|
||||
if len(pg.MonitorList) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range pg.MonitorList {
|
||||
monitor := Monitor{
|
||||
Name: m.Name,
|
||||
Group: pg.Name,
|
||||
}
|
||||
|
||||
monitorID := strconv.Itoa(m.ID)
|
||||
|
||||
if uptime, exists := heartbeatData.UptimeList[monitorID+uptimeKeySuffix]; exists {
|
||||
monitor.Uptime = uptime
|
||||
}
|
||||
|
||||
if heartbeats, exists := heartbeatData.HeartbeatList[monitorID]; exists && len(heartbeats) > 0 {
|
||||
monitor.Status = heartbeats[0].Status
|
||||
}
|
||||
|
||||
result.Monitors = append(result.Monitors, monitor)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func GetUptimeKumaStatus(c *gin.Context) {
|
||||
groups := console_setting.GetUptimeKumaGroups()
|
||||
if len(groups) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": []UptimeGroupResult{}})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), requestTimeout)
|
||||
defer cancel()
|
||||
|
||||
client := &http.Client{Timeout: httpTimeout}
|
||||
results := make([]UptimeGroupResult, len(groups))
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
for i, group := range groups {
|
||||
i, group := i, group
|
||||
g.Go(func() error {
|
||||
results[i] = fetchGroupData(gCtx, client, group)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
g.Wait()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||
}
|
||||
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
}
|
||||
if setting.DefaultUseAutoGroup {
|
||||
token.Group = "auto"
|
||||
}
|
||||
if err := token.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -459,6 +462,9 @@ func GetSelf(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
user.Remark = ""
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
@@ -943,6 +949,7 @@ type UpdateUserSettingRequest struct {
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
RecordIpLog bool `json:"record_ip_log"`
|
||||
}
|
||||
|
||||
func UpdateUserSetting(c *gin.Context) {
|
||||
@@ -1019,6 +1026,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
constant.UserSettingNotifyType: req.QuotaWarningType,
|
||||
constant.UserSettingQuotaWarningThreshold: req.QuotaWarningThreshold,
|
||||
"accept_unset_model_ratio_model": req.AcceptUnsetModelRatioModel,
|
||||
constant.UserSettingRecordIpLog: req.RecordIpLog,
|
||||
}
|
||||
|
||||
// 如果是webhook类型,添加webhook相关设置
|
||||
|
||||
@@ -178,7 +178,14 @@ type ClaudeRequest struct {
|
||||
|
||||
type Thinking struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens"`
|
||||
BudgetTokens *int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (c *Thinking) GetBudgetTokens() int {
|
||||
if c.BudgetTokens == nil {
|
||||
return 0
|
||||
}
|
||||
return *c.BudgetTokens
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) IsStringSystem() bool {
|
||||
|
||||
@@ -53,10 +53,13 @@ type GeneralOpenAIRequest struct {
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ToMap() map[string]any {
|
||||
|
||||
6
go.mod
6
go.mod
@@ -11,7 +11,6 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||
github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b
|
||||
github.com/bytedance/sonic v1.11.6
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-contrib/gzip v0.0.6
|
||||
github.com/gin-contrib/sessions v0.0.5
|
||||
@@ -25,10 +24,10 @@ require (
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.7
|
||||
github.com/samber/lo v1.39.0
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
golang.org/x/crypto v0.35.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.35.0
|
||||
@@ -43,12 +42,13 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||
github.com/aws/smithy-go v1.20.2 // indirect
|
||||
github.com/bytedance/sonic v1.11.6 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.11.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -38,8 +38,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI=
|
||||
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
|
||||
@@ -167,8 +167,6 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
|
||||
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
@@ -197,6 +195,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
|
||||
github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
|
||||
8
main.go
8
main.go
@@ -12,7 +12,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -74,7 +74,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
operation_setting.InitRatioSettings()
|
||||
ratio_setting.InitRatioSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
@@ -105,10 +105,12 @@ func main() {
|
||||
model.InitChannelCache()
|
||||
}()
|
||||
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
go model.SyncChannelCache(common.SyncFrequency)
|
||||
}
|
||||
|
||||
// 热更新配置
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
|
||||
// 数据看板
|
||||
go model.UpdateQuotaData()
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -48,9 +49,11 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
showGroup := userGroup
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
|
||||
41
middleware/stats.go
Normal file
41
middleware/stats.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HTTPStats 存储HTTP统计信息
|
||||
type HTTPStats struct {
|
||||
activeConnections int64
|
||||
}
|
||||
|
||||
var globalStats = &HTTPStats{}
|
||||
|
||||
// StatsMiddleware 统计中间件
|
||||
func StatsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 增加活跃连接数
|
||||
atomic.AddInt64(&globalStats.activeConnections, 1)
|
||||
|
||||
// 确保在请求结束时减少连接数
|
||||
defer func() {
|
||||
atomic.AddInt64(&globalStats.activeConnections, -1)
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// StatsInfo 统计信息结构
|
||||
type StatsInfo struct {
|
||||
ActiveConnections int64 `json:"active_connections"`
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func GetStats() StatsInfo {
|
||||
return StatsInfo{
|
||||
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
@@ -23,7 +24,7 @@ type Ability struct {
|
||||
func GetGroupModels(group string) []string {
|
||||
var models []string
|
||||
// Find distinct models
|
||||
DB.Table("abilities").Where(groupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
||||
return models
|
||||
}
|
||||
|
||||
@@ -41,16 +42,12 @@ func GetAllEnableAbilities() []Ability {
|
||||
}
|
||||
|
||||
func getPriority(group string, model string, retry int) (int, error) {
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var priorities []int
|
||||
err := DB.Model(&Ability{}).
|
||||
Select("DISTINCT(priority)").
|
||||
Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
|
||||
Order("priority DESC"). // 按优先级降序排序
|
||||
Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
|
||||
|
||||
if err != nil {
|
||||
@@ -75,18 +72,14 @@ func getPriority(group string, model string, retry int) (int, error) {
|
||||
}
|
||||
|
||||
func getChannelQuery(group string, model string, retry int) *gorm.DB {
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
trueVal = "true"
|
||||
}
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
|
||||
channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
|
||||
if retry != 0 {
|
||||
priority, err := getPriority(group, model, retry)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
|
||||
} else {
|
||||
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
|
||||
channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,9 +126,15 @@ func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel,
|
||||
func (channel *Channel) AddAbilities() error {
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilitySet := make(map[string]struct{})
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
key := group + "|" + model
|
||||
if _, exists := abilitySet[key]; exists {
|
||||
continue
|
||||
}
|
||||
abilitySet[key] = struct{}{}
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
@@ -152,7 +151,7 @@ func (channel *Channel) AddAbilities() error {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err := DB.Create(&chunk).Error
|
||||
err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -194,9 +193,15 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
||||
// Then add new abilities
|
||||
models_ := strings.Split(channel.Models, ",")
|
||||
groups_ := strings.Split(channel.Group, ",")
|
||||
abilitySet := make(map[string]struct{})
|
||||
abilities := make([]Ability, 0, len(models_))
|
||||
for _, model := range models_ {
|
||||
for _, group := range groups_ {
|
||||
key := group + "|" + model
|
||||
if _, exists := abilitySet[key]; exists {
|
||||
continue
|
||||
}
|
||||
abilitySet[key] = struct{}{}
|
||||
ability := Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
@@ -212,7 +217,7 @@ func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
||||
|
||||
if len(abilities) > 0 {
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err = tx.Create(&chunk).Error
|
||||
err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
||||
if err != nil {
|
||||
if isNewTx {
|
||||
tx.Rollback()
|
||||
|
||||
@@ -5,10 +5,13 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||||
var channel *Channel
|
||||
var err error
|
||||
selectGroup := group
|
||||
if group == "auto" {
|
||||
if len(setting.AutoGroups) == 0 {
|
||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||
}
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
if common.DebugEnabled {
|
||||
println("autoGroup:", autoGroup)
|
||||
}
|
||||
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||||
if channel == nil {
|
||||
continue
|
||||
} else {
|
||||
c.Set("auto_group", autoGroup)
|
||||
selectGroup = autoGroup
|
||||
if common.DebugEnabled {
|
||||
println("selectGroup:", selectGroup)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||||
if err != nil {
|
||||
return nil, group, err
|
||||
}
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, group, errors.New("channel not found")
|
||||
}
|
||||
return channel, selectGroup, nil
|
||||
}
|
||||
|
||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
}
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||
|
||||
// 构造WHERE子句
|
||||
var whereClause string
|
||||
@@ -153,15 +153,15 @@ func SearchChannels(keyword string, group string, model string, idSort bool) ([]
|
||||
if group != "" && group != "null" {
|
||||
var groupCondition string
|
||||
if common.UsingMySQL {
|
||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||
} else {
|
||||
// sqlite, PostgreSQL
|
||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||
}
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||
} else {
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||
}
|
||||
|
||||
@@ -478,7 +478,7 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
}
|
||||
|
||||
// 构造基础查询
|
||||
baseQuery := DB.Model(&Channel{}).Omit(keyCol)
|
||||
baseQuery := DB.Model(&Channel{}).Omit("key")
|
||||
|
||||
// 构造WHERE子句
|
||||
var whereClause string
|
||||
@@ -486,15 +486,15 @@ func SearchTags(keyword string, group string, model string, idSort bool) ([]*str
|
||||
if group != "" && group != "null" {
|
||||
var groupCondition string
|
||||
if common.UsingMySQL {
|
||||
groupCondition = `CONCAT(',', ` + groupCol + `, ',') LIKE ?`
|
||||
groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?`
|
||||
} else {
|
||||
// sqlite, PostgreSQL
|
||||
groupCondition = `(',' || ` + groupCol + ` || ',') LIKE ?`
|
||||
groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?`
|
||||
}
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%")
|
||||
} else {
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?"
|
||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%")
|
||||
}
|
||||
|
||||
@@ -583,3 +583,53 @@ func BatchSetChannelTag(ids []int, tag *string) error {
|
||||
// 提交事务
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// CountAllChannels returns total channels in DB
|
||||
func CountAllChannels() (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Channel{}).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// CountAllTags returns number of non-empty distinct tags
|
||||
func CountAllTags() (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Get channels of specified type with pagination
|
||||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// Count channels of specific type
|
||||
func CountChannelsByType(channelType int) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Return map[type]count for all channels
|
||||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||||
type result struct {
|
||||
Type int64 `gorm:"column:type"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.Type] = r.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
55
model/log.go
55
model/log.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -32,6 +33,7 @@ type Log struct {
|
||||
ChannelName string `json:"channel_name" gorm:"->"`
|
||||
TokenId int `json:"token_id" gorm:"default:0;index"`
|
||||
Group string `json:"group" gorm:"index"`
|
||||
Ip string `json:"ip" gorm:"index;default:''"`
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
@@ -61,7 +63,7 @@ func formatUserLogs(logs []*Log) {
|
||||
func GetLogByKey(key string) (logs []*Log, err error) {
|
||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||
var tk Token
|
||||
if err = DB.Model(&Token{}).Where(keyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||
if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
|
||||
@@ -95,6 +97,15 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
common.LogInfo(c, fmt.Sprintf("record error log: userId=%d, channelId=%d, modelName=%s, tokenName=%s, content=%s", userId, channelId, modelName, tokenName, content))
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
@@ -111,7 +122,13 @@ func RecordErrorLog(c *gin.Context, userId int, channelId int, modelName string,
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -128,6 +145,15 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
}
|
||||
username := c.GetString("username")
|
||||
otherStr := common.MapToJsonStr(other)
|
||||
// 判断是否需要记录 IP
|
||||
needRecordIp := false
|
||||
if settingMap, err := GetUserSetting(userId, false); err == nil {
|
||||
if v, ok := settingMap[constant.UserSettingRecordIpLog]; ok {
|
||||
if vb, ok := v.(bool); ok && vb {
|
||||
needRecordIp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: username,
|
||||
@@ -144,7 +170,13 @@ func RecordConsumeLog(c *gin.Context, userId int, channelId int, promptTokens in
|
||||
UseTime: useTimeSeconds,
|
||||
IsStream: isStream,
|
||||
Group: group,
|
||||
Other: otherStr,
|
||||
Ip: func() string {
|
||||
if needRecordIp {
|
||||
return c.ClientIP()
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
Other: otherStr,
|
||||
}
|
||||
err := LOG_DB.Create(log).Error
|
||||
if err != nil {
|
||||
@@ -184,7 +216,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
tx = tx.Where("logs.channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
||||
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
@@ -195,13 +227,18 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
channelIds := make([]int, 0)
|
||||
channelIdsMap := make(map[int]struct{})
|
||||
channelMap := make(map[int]string)
|
||||
for _, log := range logs {
|
||||
if log.ChannelId != 0 {
|
||||
channelIds = append(channelIds, log.ChannelId)
|
||||
channelIdsMap[log.ChannelId] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
channelIds := make([]int, 0, len(channelIdsMap))
|
||||
for channelId := range channelIdsMap {
|
||||
channelIds = append(channelIds, channelId)
|
||||
}
|
||||
if len(channelIds) > 0 {
|
||||
var channels []struct {
|
||||
Id int `gorm:"column:id"`
|
||||
@@ -242,7 +279,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
||||
tx = tx.Where("logs.created_at <= ?", endTimestamp)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where("logs."+groupCol+" = ?", group)
|
||||
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
||||
}
|
||||
err = tx.Model(&Log{}).Count(&total).Error
|
||||
if err != nil {
|
||||
@@ -303,8 +340,8 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||
}
|
||||
if group != "" {
|
||||
tx = tx.Where(groupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
|
||||
tx = tx.Where(logGroupCol+" = ?", group)
|
||||
rpmTpmQuery = rpmTpmQuery.Where(logGroupCol+" = ?", group)
|
||||
}
|
||||
|
||||
tx = tx.Where("type = ?", LogTypeConsume)
|
||||
|
||||
172
model/main.go
172
model/main.go
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -15,18 +16,48 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var groupCol string
|
||||
var keyCol string
|
||||
var commonGroupCol string
|
||||
var commonKeyCol string
|
||||
var commonTrueVal string
|
||||
var commonFalseVal string
|
||||
|
||||
var logKeyCol string
|
||||
var logGroupCol string
|
||||
|
||||
func initCol() {
|
||||
// init common column names
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
keyCol = `"key"`
|
||||
|
||||
commonGroupCol = `"group"`
|
||||
commonKeyCol = `"key"`
|
||||
commonTrueVal = "true"
|
||||
commonFalseVal = "false"
|
||||
} else {
|
||||
groupCol = "`group`"
|
||||
keyCol = "`key`"
|
||||
commonGroupCol = "`group`"
|
||||
commonKeyCol = "`key`"
|
||||
commonTrueVal = "1"
|
||||
commonFalseVal = "0"
|
||||
}
|
||||
if os.Getenv("LOG_SQL_DSN") != "" {
|
||||
switch common.LogSqlType {
|
||||
case common.DatabaseTypePostgreSQL:
|
||||
logGroupCol = `"group"`
|
||||
logKeyCol = `"key"`
|
||||
default:
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
} else {
|
||||
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
|
||||
if common.UsingPostgreSQL {
|
||||
logGroupCol = `"group"`
|
||||
logKeyCol = `"key"`
|
||||
} else {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
}
|
||||
// log sql type and database type
|
||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
}
|
||||
|
||||
var DB *gorm.DB
|
||||
@@ -83,7 +114,7 @@ func CheckSetup() {
|
||||
}
|
||||
}
|
||||
|
||||
func chooseDB(envName string) (*gorm.DB, error) {
|
||||
func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
|
||||
defer func() {
|
||||
initCol()
|
||||
}()
|
||||
@@ -92,7 +123,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
if !isLog {
|
||||
common.UsingPostgreSQL = true
|
||||
} else {
|
||||
common.LogSqlType = common.DatabaseTypePostgreSQL
|
||||
}
|
||||
return gorm.Open(postgres.New(postgres.Config{
|
||||
DSN: dsn,
|
||||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||||
@@ -102,7 +137,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
}
|
||||
if strings.HasPrefix(dsn, "local") {
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
if !isLog {
|
||||
common.UsingSQLite = true
|
||||
} else {
|
||||
common.LogSqlType = common.DatabaseTypeSQLite
|
||||
}
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
@@ -117,7 +156,11 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
dsn += "?parseTime=true"
|
||||
}
|
||||
}
|
||||
common.UsingMySQL = true
|
||||
if !isLog {
|
||||
common.UsingMySQL = true
|
||||
} else {
|
||||
common.LogSqlType = common.DatabaseTypeMySQL
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
@@ -131,7 +174,7 @@ func chooseDB(envName string) (*gorm.DB, error) {
|
||||
}
|
||||
|
||||
func InitDB() (err error) {
|
||||
db, err := chooseDB("SQL_DSN")
|
||||
db, err := chooseDB("SQL_DSN", false)
|
||||
if err == nil {
|
||||
if common.DebugEnabled {
|
||||
db = db.Debug()
|
||||
@@ -149,7 +192,7 @@ func InitDB() (err error) {
|
||||
return nil
|
||||
}
|
||||
if common.UsingMySQL {
|
||||
_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
||||
//_, _ = sqlDB.Exec("ALTER TABLE channels MODIFY model_mapping TEXT;") // TODO: delete this line when most users have upgraded
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
err = migrateDB()
|
||||
@@ -165,7 +208,7 @@ func InitLogDB() (err error) {
|
||||
LOG_DB = DB
|
||||
return
|
||||
}
|
||||
db, err := chooseDB("LOG_SQL_DSN")
|
||||
db, err := chooseDB("LOG_SQL_DSN", true)
|
||||
if err == nil {
|
||||
if common.DebugEnabled {
|
||||
db = db.Debug()
|
||||
@@ -198,54 +241,73 @@ func InitLogDB() (err error) {
|
||||
}
|
||||
|
||||
func migrateDB() error {
|
||||
err := DB.AutoMigrate(&Channel{})
|
||||
if !common.UsingPostgreSQL {
|
||||
return migrateDBFast()
|
||||
}
|
||||
err := DB.AutoMigrate(
|
||||
&Channel{},
|
||||
&Token{},
|
||||
&User{},
|
||||
&Option{},
|
||||
&Redemption{},
|
||||
&Ability{},
|
||||
&Log{},
|
||||
&Midjourney{},
|
||||
&TopUp{},
|
||||
&QuotaData{},
|
||||
&Task{},
|
||||
&Setup{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Token{})
|
||||
if err != nil {
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateDBFast() error {
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, 12) // Buffer size matches number of migrations
|
||||
|
||||
migrations := []struct {
|
||||
model interface{}
|
||||
name string
|
||||
}{
|
||||
{&Channel{}, "Channel"},
|
||||
{&Token{}, "Token"},
|
||||
{&User{}, "User"},
|
||||
{&Option{}, "Option"},
|
||||
{&Redemption{}, "Redemption"},
|
||||
{&Ability{}, "Ability"},
|
||||
{&Log{}, "Log"},
|
||||
{&Midjourney{}, "Midjourney"},
|
||||
{&TopUp{}, "TopUp"},
|
||||
{&QuotaData{}, "QuotaData"},
|
||||
{&Task{}, "Task"},
|
||||
{&Setup{}, "Setup"},
|
||||
}
|
||||
err = DB.AutoMigrate(&User{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
for _, m := range migrations {
|
||||
wg.Add(1)
|
||||
go func(model interface{}, name string) {
|
||||
defer wg.Done()
|
||||
if err := DB.AutoMigrate(model); err != nil {
|
||||
errChan <- fmt.Errorf("failed to migrate %s: %v", name, err)
|
||||
}
|
||||
}(m.model, m.name)
|
||||
}
|
||||
err = DB.AutoMigrate(&Option{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
// Wait for all migrations to complete
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Check for any errors
|
||||
for err := range errChan {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = DB.AutoMigrate(&Redemption{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Ability{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Log{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Midjourney{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&TopUp{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&QuotaData{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Task{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.AutoMigrate(&Setup{})
|
||||
common.SysLog("database migrated")
|
||||
//err = createRootAccountIfNeed()
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateLOGDB() error {
|
||||
|
||||
@@ -166,3 +166,40 @@ func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
|
||||
Where("id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
// CountAllTasks returns total midjourney tasks for admin query
|
||||
func CountAllTasks(queryParams TaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Midjourney{})
|
||||
if queryParams.ChannelID != "" {
|
||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||
}
|
||||
if queryParams.MjID != "" {
|
||||
query = query.Where("mj_id = ?", queryParams.MjID)
|
||||
}
|
||||
if queryParams.StartTimestamp != "" {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != "" {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
// CountAllUserTask returns total midjourney tasks for user
|
||||
func CountAllUserTask(userId int, queryParams TaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Midjourney{}).Where("user_id = ?", userId)
|
||||
if queryParams.MjID != "" {
|
||||
query = query.Where("mj_id = ?", queryParams.MjID)
|
||||
}
|
||||
if queryParams.StartTimestamp != "" {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != "" {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -76,6 +77,9 @@ func InitOptionMap() {
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -94,12 +98,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
@@ -122,7 +127,6 @@ func InitOptionMap() {
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||
common.OptionMap["ApiInfo"] = ""
|
||||
|
||||
// 自动添加所有注册的模型配置
|
||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||
@@ -192,7 +196,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ImageDownloadPermission = intValue
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
|
||||
boolValue := value == "true"
|
||||
switch key {
|
||||
case "PasswordRegisterEnabled":
|
||||
@@ -261,6 +265,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
case "WorkerAllowHttpImageRequestEnabled":
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -287,6 +293,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "AutoGroups":
|
||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
@@ -352,17 +360,19 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DataExportDefaultTime":
|
||||
common.DataExportDefaultTime = value
|
||||
case "ModelRatio":
|
||||
err = operation_setting.UpdateModelRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupRatioByJSONString(value)
|
||||
case "GroupGroupRatio":
|
||||
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = operation_setting.UpdateCompletionRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = operation_setting.UpdateModelPriceByJSONString(value)
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = operation_setting.UpdateCacheRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
//case "ChatLink":
|
||||
@@ -379,6 +389,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -65,14 +65,14 @@ func updatePricing() {
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
}
|
||||
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
pricing.ModelPrice = modelPrice
|
||||
pricing.QuotaType = 1
|
||||
} else {
|
||||
modelRatio, _ := operation_setting.GetModelRatio(model)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(model)
|
||||
pricing.ModelRatio = modelRatio
|
||||
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
|
||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
|
||||
@@ -21,6 +21,7 @@ type Redemption struct {
|
||||
Count int `json:"count" gorm:"-:all"` // only for api request
|
||||
UsedUserId int `json:"used_user_id"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint"` // 过期时间,0 表示不过期
|
||||
}
|
||||
|
||||
func GetAllRedemptions(startIdx int, num int) (redemptions []*Redemption, total int64, err error) {
|
||||
@@ -131,6 +132,9 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
if redemption.Status != common.RedemptionCodeStatusEnabled {
|
||||
return errors.New("该兑换码已被使用")
|
||||
}
|
||||
if redemption.ExpiredTime != 0 && redemption.ExpiredTime < common.GetTimestamp() {
|
||||
return errors.New("该兑换码已过期")
|
||||
}
|
||||
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -162,7 +166,7 @@ func (redemption *Redemption) SelectUpdate() error {
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (redemption *Redemption) Update() error {
|
||||
var err error
|
||||
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time").Updates(redemption).Error
|
||||
err = DB.Model(redemption).Select("name", "status", "quota", "redeemed_time", "expired_time").Updates(redemption).Error
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -183,3 +187,9 @@ func DeleteRedemptionById(id int) (err error) {
|
||||
}
|
||||
return redemption.Delete()
|
||||
}
|
||||
|
||||
func DeleteInvalidRedemptions() (int64, error) {
|
||||
now := common.GetTimestamp()
|
||||
result := DB.Where("status IN ? OR (status = ? AND expired_time != 0 AND expired_time < ?)", []int{common.RedemptionCodeStatusUsed, common.RedemptionCodeStatusDisabled}, common.RedemptionCodeStatusEnabled, now).Delete(&Redemption{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -302,3 +302,64 @@ func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, e
|
||||
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
|
||||
return stat, err
|
||||
}
|
||||
|
||||
// TaskCountAllTasks returns total tasks that match the given query params (admin usage)
|
||||
func TaskCountAllTasks(queryParams SyncTaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Task{})
|
||||
if queryParams.ChannelID != "" {
|
||||
query = query.Where("channel_id = ?", queryParams.ChannelID)
|
||||
}
|
||||
if queryParams.Platform != "" {
|
||||
query = query.Where("platform = ?", queryParams.Platform)
|
||||
}
|
||||
if queryParams.UserID != "" {
|
||||
query = query.Where("user_id = ?", queryParams.UserID)
|
||||
}
|
||||
if len(queryParams.UserIDs) != 0 {
|
||||
query = query.Where("user_id in (?)", queryParams.UserIDs)
|
||||
}
|
||||
if queryParams.TaskID != "" {
|
||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||
}
|
||||
if queryParams.Action != "" {
|
||||
query = query.Where("action = ?", queryParams.Action)
|
||||
}
|
||||
if queryParams.Status != "" {
|
||||
query = query.Where("status = ?", queryParams.Status)
|
||||
}
|
||||
if queryParams.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
// TaskCountAllUserTask returns total tasks for given user
|
||||
func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 {
|
||||
var total int64
|
||||
query := DB.Model(&Task{}).Where("user_id = ?", userId)
|
||||
if queryParams.TaskID != "" {
|
||||
query = query.Where("task_id = ?", queryParams.TaskID)
|
||||
}
|
||||
if queryParams.Action != "" {
|
||||
query = query.Where("action = ?", queryParams.Action)
|
||||
}
|
||||
if queryParams.Status != "" {
|
||||
query = query.Where("status = ?", queryParams.Status)
|
||||
}
|
||||
if queryParams.Platform != "" {
|
||||
query = query.Where("platform = ?", queryParams.Platform)
|
||||
}
|
||||
if queryParams.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
|
||||
}
|
||||
if queryParams.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
|
||||
}
|
||||
_ = query.Count(&total).Error
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token
|
||||
if token != "" {
|
||||
token = strings.Trim(token, "sk-")
|
||||
}
|
||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(keyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
||||
err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func GetTokenByKey(key string, fromDB bool) (token *Token, err error) {
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
err = DB.Where(commonKeyCol+" = ?", key).First(&token).Error
|
||||
return token, err
|
||||
}
|
||||
|
||||
@@ -320,3 +320,10 @@ func decreaseTokenQuota(id int, quota int) (err error) {
|
||||
).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// CountUserTokens returns total number of tokens for the given user, used for pagination
|
||||
func CountUserTokens(userId int) (int64, error) {
|
||||
var total int64
|
||||
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func cacheSetToken(token Token) error {
|
||||
key := common.GenerateHMAC(token.Key)
|
||||
token.Clean()
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.TokenCacheSeconds)*time.Second)
|
||||
err := common.RedisHSetObj(fmt.Sprintf("token:%s", key), &token, time.Duration(constant.RedisKeyCacheSeconds())*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ type User struct {
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
LinuxDOId string `json:"linux_do_id" gorm:"column:linux_do_id;index"`
|
||||
Setting string `json:"setting" gorm:"type:text;column:setting"`
|
||||
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||
}
|
||||
|
||||
func (user *User) ToBaseUser() *UserBase {
|
||||
@@ -175,7 +176,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
||||
// 如果是数字,同时搜索ID和其他字段
|
||||
likeCondition = "id = ? OR " + likeCondition
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
keywordInt, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
@@ -184,7 +185,7 @@ func SearchUsers(keyword string, group string, startIdx int, num int) ([]*User,
|
||||
} else {
|
||||
// 非数字关键字,只搜索字符串字段
|
||||
if group != "" {
|
||||
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?",
|
||||
query = query.Where("("+likeCondition+") AND "+commonGroupCol+" = ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||
} else {
|
||||
query = query.Where(likeCondition,
|
||||
@@ -366,6 +367,7 @@ func (user *User) Edit(updatePassword bool) error {
|
||||
"display_name": newUser.DisplayName,
|
||||
"group": newUser.Group,
|
||||
"quota": newUser.Quota,
|
||||
"remark": newUser.Remark,
|
||||
}
|
||||
if updatePassword {
|
||||
updates["password"] = newUser.Password
|
||||
@@ -615,7 +617,7 @@ func GetUserGroup(id int, fromDB bool) (group string, err error) {
|
||||
// Don't return error - fall through to DB
|
||||
}
|
||||
fromDB = true
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(commonGroupCol).Find(&group).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func updateUserCache(user User) error {
|
||||
return common.RedisHSetObj(
|
||||
getUserCacheKey(user.Id),
|
||||
user.ToBaseUser(),
|
||||
time.Duration(constant.UserId2QuotaCacheSeconds)*time.Second,
|
||||
time.Duration(constant.RedisKeyCacheSeconds())*time.Second,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
|
||||
}
|
||||
|
||||
func batchUpdate() {
|
||||
// check if there's any data to update
|
||||
hasData := false
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
if len(batchUpdateStores[i]) > 0 {
|
||||
hasData = true
|
||||
batchUpdateLocks[i].Unlock()
|
||||
break
|
||||
}
|
||||
batchUpdateLocks[i].Unlock()
|
||||
}
|
||||
|
||||
if !hasData {
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
|
||||
@@ -109,6 +109,12 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
|
||||
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
// 增加panic恢复处理
|
||||
if r := recover(); r != nil {
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
|
||||
}
|
||||
}
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine stopped.")
|
||||
}
|
||||
@@ -119,19 +125,32 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(pingInterval)
|
||||
// 退出时清理 ticker
|
||||
defer ticker.Stop()
|
||||
// 确保在任何情况下都清理ticker
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping ticker stopped")
|
||||
}
|
||||
}()
|
||||
|
||||
var pingMutex sync.Mutex
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine started")
|
||||
}
|
||||
|
||||
// 增加超时控制,防止goroutine长时间运行
|
||||
maxPingDuration := 120 * time.Minute // 最大ping持续时间
|
||||
pingTimeout := time.NewTimer(maxPingDuration)
|
||||
defer pingTimeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
// 发送 ping 数据
|
||||
case <-ticker.C:
|
||||
if err := sendPingData(c, &pingMutex); err != nil {
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping error, stopping goroutine:", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
// 收到退出信号
|
||||
@@ -140,6 +159,12 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
|
||||
// request 结束
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
// 超时保护,防止goroutine无限运行
|
||||
case <-pingTimeout.C:
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine timeout, stopping")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -148,19 +173,34 @@ func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.Canc
|
||||
}
|
||||
|
||||
func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
// 增加超时控制,防止锁死等待
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
err := helper.PingData(c)
|
||||
if err != nil {
|
||||
common2.LogError(c, "SSE ping error: "+err.Error())
|
||||
err := helper.PingData(c)
|
||||
if err != nil {
|
||||
common2.LogError(c, "SSE ping error: "+err.Error())
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping data sent.")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
// 设置发送ping数据的超时时间
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-time.After(10 * time.Second):
|
||||
return errors.New("SSE ping data send timeout")
|
||||
case <-c.Request.Context().Done():
|
||||
return errors.New("request context cancelled during ping")
|
||||
}
|
||||
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping data sent.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
||||
@@ -175,15 +215,23 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
client = service.GetHttpClient()
|
||||
}
|
||||
|
||||
var stopPinger context.CancelFunc
|
||||
if info.IsStream {
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
// 处理流式请求的 ping 保活
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
if generalSettings.PingIntervalEnabled {
|
||||
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
||||
stopPinger := startPingKeepAlive(c, pingInterval)
|
||||
defer stopPinger()
|
||||
stopPinger = startPingKeepAlive(c, pingInterval)
|
||||
// 使用defer确保在任何情况下都能停止ping goroutine
|
||||
defer func() {
|
||||
if stopPinger != nil {
|
||||
stopPinger()
|
||||
if common2.DebugEnabled {
|
||||
println("SSE ping goroutine stopped by defer")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
@@ -454,6 +454,7 @@ type ClaudeResponseInfo struct {
|
||||
Model string
|
||||
ResponseText strings.Builder
|
||||
Usage *dto.Usage
|
||||
Done bool
|
||||
}
|
||||
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
@@ -461,20 +462,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
|
||||
// message_start, 获取usage
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != "" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
|
||||
// 判断是否完整
|
||||
claudeInfo.Done = true
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
@@ -506,25 +519,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
@@ -544,29 +547,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 说明流模式建立失败,可能为官方出错
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
//
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
|
||||
@@ -3,7 +3,6 @@ package cohere
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -78,7 +77,7 @@ func stopReasonCohere2OpenAI(reason string) string {
|
||||
}
|
||||
|
||||
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
usage := &dto.Usage{}
|
||||
responseText := ""
|
||||
|
||||
@@ -72,8 +72,11 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// suffix -thinking and -nothinking
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.OriginModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
|
||||
@@ -140,6 +140,7 @@ type GeminiChatGenerationConfig struct {
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
|
||||
@@ -35,23 +35,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 检查是否有候选响应
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return nil, &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
@@ -108,7 +95,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
@@ -136,7 +123,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
}
|
||||
|
||||
// 计算最终使用量
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||
//helper.Done(c)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/model_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -36,6 +37,47 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"video/flv": true,
|
||||
}
|
||||
|
||||
// Gemini 允许的思考预算范围
|
||||
const (
|
||||
pro25MinBudget = 128
|
||||
pro25MaxBudget = 32768
|
||||
flash25MaxBudget = 24576
|
||||
flash25LiteMinBudget = 512
|
||||
flash25LiteMaxBudget = 24576
|
||||
)
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
|
||||
if is25FlashLite {
|
||||
if budget < flash25LiteMinBudget {
|
||||
return flash25LiteMinBudget
|
||||
}
|
||||
if budget > flash25LiteMaxBudget {
|
||||
return flash25LiteMaxBudget
|
||||
}
|
||||
} else if isNew25Pro {
|
||||
if budget < pro25MinBudget {
|
||||
return pro25MinBudget
|
||||
}
|
||||
if budget > pro25MaxBudget {
|
||||
return pro25MaxBudget
|
||||
}
|
||||
} else { // 其他模型
|
||||
if budget < 0 {
|
||||
return 0
|
||||
}
|
||||
if budget > flash25MaxBudget {
|
||||
return flash25MaxBudget
|
||||
}
|
||||
}
|
||||
return budget
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
|
||||
@@ -57,16 +99,31 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 硬编码不支持 ThinkingBudget 的旧模型
|
||||
modelName := info.OriginModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
|
||||
if strings.Contains(modelName, "-thinking-") {
|
||||
parts := strings.SplitN(modelName, "-thinking-", 2)
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-thinking") {
|
||||
unsupportedModels := []string{
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
}
|
||||
|
||||
isUnsupported := false
|
||||
for _, unsupportedModel := range unsupportedModels {
|
||||
if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
|
||||
if strings.HasPrefix(modelName, unsupportedModel) {
|
||||
isUnsupported = true
|
||||
break
|
||||
}
|
||||
@@ -78,39 +135,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
} else {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
|
||||
// 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if isNew25Pro {
|
||||
// 新的2.5pro模型:ThinkingBudget范围为128-32768
|
||||
if budgetTokens == 0 || budgetTokens < 128 {
|
||||
budgetTokens = 128
|
||||
} else if budgetTokens > 32768 {
|
||||
budgetTokens = 32768
|
||||
}
|
||||
} else {
|
||||
// 其他模型:ThinkingBudget范围为0-24576
|
||||
if budgetTokens == 0 || budgetTokens > 24576 {
|
||||
budgetTokens = 24576
|
||||
}
|
||||
}
|
||||
|
||||
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(int(budgetTokens)),
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
// 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if !isNew25Pro {
|
||||
// 只有非新2.5pro模型才支持-nothinking
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro && !is25FlashLite {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
@@ -283,7 +315,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
|
||||
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
||||
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
||||
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
|
||||
url := part.GetImageMedia().Url
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
@@ -341,7 +374,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
if len(content.Parts) > 0 {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
@@ -611,9 +646,9 @@ func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
|
||||
}
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
|
||||
fullTextResponse := dto.OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
@@ -754,7 +789,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
// responseText := ""
|
||||
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
@@ -849,7 +884,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
|
||||
fullTextResponse.Model = info.UpstreamModelName
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
|
||||
@@ -88,6 +88,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||
|
||||
// 特殊处理 responses API
|
||||
if info.RelayMode == constant.RelayModeResponses {
|
||||
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
|
||||
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
||||
}
|
||||
|
||||
model_ := info.UpstreamModelName
|
||||
// 2025年5月10日后创建的渠道不移除.
|
||||
if info.ChannelCreateTime < constant2.AzureNoRemoveDotTime {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
@@ -345,13 +346,14 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
if err = c.ShouldBind(&reqBody); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||
reqFp, err := reqBody.File.Open()
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
defer reqFp.Close()
|
||||
|
||||
tmpFp, err := os.CreateTemp("", "audio-*")
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
@@ -365,7 +367,7 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name())
|
||||
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package palm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -73,7 +72,7 @@ func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompleti
|
||||
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
responseId := helper.GetResponseID(c)
|
||||
createdTime := common.GetTimestamp()
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
|
||||
@@ -123,14 +123,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
|
||||
@@ -98,7 +98,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
// BudgetTokens 为 max_tokens 的 80%
|
||||
textRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
|
||||
BudgetTokens: common.GetPointer[int](int(float64(textRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
|
||||
}
|
||||
// TODO: 临时处理
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
|
||||
|
||||
@@ -61,6 +61,7 @@ type RelayInfo struct {
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
UserGroup string
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
@@ -204,6 +205,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
UserGroup: c.GetString(constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
FirstResponseTime: startTime.Add(-time.Second),
|
||||
|
||||
@@ -2,14 +2,19 @@ package helper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type GroupRatioInfo struct {
|
||||
GroupRatio float64
|
||||
GroupSpecialRatio float64
|
||||
}
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
@@ -17,18 +22,50 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
GroupRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
}
|
||||
|
||||
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.Group if present
|
||||
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
|
||||
groupRatioInfo := GroupRatioInfo{
|
||||
GroupRatio: 1.0, // default ratio
|
||||
GroupSpecialRatio: -1,
|
||||
}
|
||||
|
||||
// check auto group
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("final group: %s", autoGroup))
|
||||
}
|
||||
relayInfo.Group = autoGroup.(string)
|
||||
}
|
||||
|
||||
// check user group special ratio
|
||||
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
// user group special ratio
|
||||
groupRatioInfo.GroupSpecialRatio = userGroupRatio
|
||||
groupRatioInfo.GroupRatio = userGroupRatio
|
||||
} else {
|
||||
// normal group ratio
|
||||
groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.Group)
|
||||
}
|
||||
|
||||
return groupRatioInfo
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
@@ -41,7 +78,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
preConsumedTokens = promptTokens + maxTokens
|
||||
}
|
||||
var success bool
|
||||
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
|
||||
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
|
||||
if !success {
|
||||
acceptUnsetRatio := false
|
||||
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
|
||||
@@ -54,21 +91,21 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
||||
}
|
||||
}
|
||||
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
}
|
||||
|
||||
priceData := PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
GroupRatio: groupRatio,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
@@ -84,11 +121,11 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
}
|
||||
|
||||
func ContainPriceOrRatio(modelName string) bool {
|
||||
_, ok := operation_setting.GetModelPrice(modelName, false)
|
||||
_, ok := ratio_setting.GetModelPrice(modelName, false)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
_, ok = operation_setting.GetModelRatio(modelName)
|
||||
_, ok = ratio_setting.GetModelRatio(modelName)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package helper
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -19,8 +20,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
InitialScannerBufferSize = 1 << 20 // 1MB (1*1024*1024)
|
||||
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
|
||||
InitialScannerBufferSize = 64 << 10 // 64KB (64*1024)
|
||||
MaxScannerBufferSize = 10 << 20 // 10MB (10*1024*1024)
|
||||
DefaultPingInterval = 10 * time.Second
|
||||
)
|
||||
|
||||
@@ -30,7 +31,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
return
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
// 确保响应体总是被关闭
|
||||
defer func() {
|
||||
if resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second
|
||||
if strings.HasPrefix(info.UpstreamModelName, "o") {
|
||||
@@ -39,11 +45,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
}
|
||||
|
||||
var (
|
||||
stopChan = make(chan bool, 2)
|
||||
stopChan = make(chan bool, 3) // 增加缓冲区避免阻塞
|
||||
scanner = bufio.NewScanner(resp.Body)
|
||||
ticker = time.NewTicker(streamingTimeout)
|
||||
pingTicker *time.Ticker
|
||||
writeMutex sync.Mutex // Mutex to protect concurrent writes
|
||||
wg sync.WaitGroup // 用于等待所有 goroutine 退出
|
||||
)
|
||||
|
||||
generalSettings := operation_setting.GetGeneralSetting()
|
||||
@@ -57,13 +64,32 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
pingTicker = time.NewTicker(pingInterval)
|
||||
}
|
||||
|
||||
// 改进资源清理,确保所有 goroutine 正确退出
|
||||
defer func() {
|
||||
// 通知所有 goroutine 停止
|
||||
common.SafeSendBool(stopChan, true)
|
||||
|
||||
ticker.Stop()
|
||||
if pingTicker != nil {
|
||||
pingTicker.Stop()
|
||||
}
|
||||
|
||||
// 等待所有 goroutine 退出,最多等待5秒
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
common.LogError(c, "timeout waiting for goroutines to exit")
|
||||
}
|
||||
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
SetEventStreamHeaders(c)
|
||||
@@ -73,35 +99,95 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
|
||||
ctx = context.WithValue(ctx, "stop_chan", stopChan)
|
||||
|
||||
// Handle ping data sending
|
||||
// Handle ping data sending with improved error handling
|
||||
if pingEnabled && pingTicker != nil {
|
||||
wg.Add(1)
|
||||
gopool.Go(func() {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
common.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r))
|
||||
common.SafeSendBool(stopChan, true)
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("ping goroutine exited")
|
||||
}
|
||||
}()
|
||||
|
||||
// 添加超时保护,防止 goroutine 无限运行
|
||||
maxPingDuration := 30 * time.Minute // 最大 ping 持续时间
|
||||
pingTimeout := time.NewTimer(maxPingDuration)
|
||||
defer pingTimeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pingTicker.C:
|
||||
writeMutex.Lock() // Lock before writing
|
||||
err := PingData(c)
|
||||
writeMutex.Unlock() // Unlock after writing
|
||||
if err != nil {
|
||||
common.LogError(c, "ping data error: "+err.Error())
|
||||
common.SafeSendBool(stopChan, true)
|
||||
// 使用超时机制防止写操作阻塞
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
writeMutex.Lock()
|
||||
defer writeMutex.Unlock()
|
||||
done <- PingData(c)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
common.LogError(c, "ping data error: "+err.Error())
|
||||
return
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("ping data sent")
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
common.LogError(c, "ping data send timeout")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
if common.DebugEnabled {
|
||||
println("ping data sent")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
if common.DebugEnabled {
|
||||
println("ping data goroutine stopped")
|
||||
}
|
||||
return
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-c.Request.Context().Done():
|
||||
// 监听客户端断开连接
|
||||
return
|
||||
case <-pingTimeout.C:
|
||||
common.LogError(c, "ping goroutine max duration reached")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Scanner goroutine with improved error handling
|
||||
wg.Add(1)
|
||||
common.RelayCtxGo(ctx, func() {
|
||||
defer func() {
|
||||
wg.Done()
|
||||
if r := recover(); r != nil {
|
||||
common.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r))
|
||||
}
|
||||
common.SafeSendBool(stopChan, true)
|
||||
if common.DebugEnabled {
|
||||
println("scanner goroutine exited")
|
||||
}
|
||||
}()
|
||||
|
||||
for scanner.Scan() {
|
||||
// 检查是否需要停止
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
ticker.Reset(streamingTimeout)
|
||||
data := scanner.Text()
|
||||
if common.DebugEnabled {
|
||||
@@ -119,11 +205,27 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
info.SetFirstResponseTime()
|
||||
writeMutex.Lock() // Lock before writing
|
||||
success := dataHandler(data)
|
||||
writeMutex.Unlock() // Unlock after writing
|
||||
if !success {
|
||||
break
|
||||
|
||||
// 使用超时机制防止写操作阻塞
|
||||
done := make(chan bool, 1)
|
||||
go func() {
|
||||
writeMutex.Lock()
|
||||
defer writeMutex.Unlock()
|
||||
done <- dataHandler(data)
|
||||
}()
|
||||
|
||||
select {
|
||||
case success := <-done:
|
||||
if !success {
|
||||
return
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
common.LogError(c, "data handler timeout")
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,17 +235,18 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
|
||||
common.LogError(c, "scanner error: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
common.SafeSendBool(stopChan, true)
|
||||
})
|
||||
|
||||
// 主循环等待完成或超时
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// 超时处理逻辑
|
||||
common.LogError(c, "streaming timeout")
|
||||
common.SafeSendBool(stopChan, true)
|
||||
case <-stopChan:
|
||||
// 正常结束
|
||||
common.LogInfo(c, "streaming finished")
|
||||
case <-c.Request.Context().Done():
|
||||
// 客户端断开连接
|
||||
common.LogInfo(c, "client disconnected")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,19 +136,52 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
adaptor.Init(relayInfo)
|
||||
|
||||
// Clean up empty system instruction
|
||||
if req.SystemInstructions != nil {
|
||||
hasContent := false
|
||||
for _, part := range req.SystemInstructions.Parts {
|
||||
if part.Text != "" {
|
||||
hasContent = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasContent {
|
||||
req.SystemInstructions = nil
|
||||
}
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(requestBody))
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
|
||||
if openaiErr != nil {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
// reset model price
|
||||
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -174,17 +174,17 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
@@ -480,17 +480,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
|
||||
@@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
@@ -361,7 +362,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
cacheRatio := priceData.CacheRatio
|
||||
imageRatio := priceData.ImageRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
@@ -510,7 +511,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
|
||||
@@ -15,8 +15,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -38,9 +37,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
@@ -49,7 +48,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
//if constant.ShouldCheckPromptSensitive() {
|
||||
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||
// if err != nil {
|
||||
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
// }
|
||||
//}
|
||||
|
||||
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||
//// count messages token error 计算promptTokens错误
|
||||
//if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
//}
|
||||
//
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||
//}
|
||||
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
relayInfo.UsePrice = true
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
return openaiErr
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
||||
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/setup", controller.GetSetup)
|
||||
apiRouter.POST("/setup", controller.PostSetup)
|
||||
apiRouter.GET("/status", controller.GetStatus)
|
||||
apiRouter.GET("/uptime/status", controller.GetUptimeKumaStatus)
|
||||
apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
|
||||
apiRouter.GET("/status/test", middleware.AdminAuth(), controller.TestStatus)
|
||||
apiRouter.GET("/notice", controller.GetNotice)
|
||||
@@ -80,6 +81,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
optionRoute.GET("/", controller.GetOptions)
|
||||
optionRoute.PUT("/", controller.UpdateOption)
|
||||
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
|
||||
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
||||
}
|
||||
channelRoute := apiRouter.Group("/channel")
|
||||
channelRoute.Use(middleware.AdminAuth())
|
||||
@@ -125,6 +127,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
redemptionRoute.GET("/:id", controller.GetRedemption)
|
||||
redemptionRoute.POST("/", controller.AddRedemption)
|
||||
redemptionRoute.PUT("/", controller.UpdateRedemption)
|
||||
redemptionRoute.DELETE("/invalid", controller.DeleteInvalidRedemption)
|
||||
redemptionRoute.DELETE("/:id", controller.DeleteRedemption)
|
||||
}
|
||||
logRoute := apiRouter.Group("/log")
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
func SetRelayRouter(router *gin.Engine) {
|
||||
router.Use(middleware.CORS())
|
||||
router.Use(middleware.DecompressRequestMiddleware())
|
||||
router.Use(middleware.StatsMiddleware())
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
modelsRouter := router.Group("/v1/models")
|
||||
modelsRouter.Use(middleware.TokenAuth())
|
||||
|
||||
@@ -59,6 +59,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
|
||||
return true
|
||||
case "billing_not_active":
|
||||
return true
|
||||
case "pre_consume_token_quota_failed":
|
||||
return true
|
||||
}
|
||||
switch err.Error.Type {
|
||||
case "insufficient_quota":
|
||||
|
||||
@@ -21,10 +21,10 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
|
||||
|
||||
isOpenRouter := info.ChannelType == common.ChannelTypeOpenRouter
|
||||
|
||||
if claudeRequest.Thinking != nil {
|
||||
if claudeRequest.Thinking != nil && claudeRequest.Thinking.Type == "enabled" {
|
||||
if isOpenRouter {
|
||||
reasoning := openrouter.RequestReasoning{
|
||||
MaxTokens: claudeRequest.Thinking.BudgetTokens,
|
||||
MaxTokens: claudeRequest.Thinking.GetBudgetTokens(),
|
||||
}
|
||||
reasoningJSON, err := json.Marshal(reasoning)
|
||||
if err != nil {
|
||||
|
||||
@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
|
||||
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
lowerText := strings.ToLower(text)
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: text,
|
||||
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
|
||||
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
lowerText := strings.ToLower(text)
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
if !strings.HasPrefix(lowerText, "get file base64 from url") {
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
claudeError := dto.ClaudeError{
|
||||
Message: text,
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
@@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
// Convert to base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
mimeType := resp.Header.Get("Content-Type")
|
||||
if len(strings.Split(mimeType, ";")) > 1 {
|
||||
// If Content-Type has parameters, take the first part
|
||||
mimeType = strings.Split(mimeType, ";")[0]
|
||||
}
|
||||
if mimeType == "application/octet-stream" {
|
||||
if common.DebugEnabled {
|
||||
println("MIME type is application/octet-stream, trying to guess from URL or filename")
|
||||
}
|
||||
// try to guess the MIME type from the url last segment
|
||||
urlParts := strings.Split(url, "/")
|
||||
if len(urlParts) > 0 {
|
||||
lastSegment := urlParts[len(urlParts)-1]
|
||||
if strings.Contains(lastSegment, ".") {
|
||||
// Extract the file extension
|
||||
filename := strings.Split(lastSegment, ".")
|
||||
if len(filename) > 1 {
|
||||
ext := strings.ToLower(filename[len(filename)-1])
|
||||
// Guess MIME type based on file extension
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// try to guess the MIME type from the file extension
|
||||
fileName := resp.Header.Get("Content-Disposition")
|
||||
if fileName != "" {
|
||||
// Extract the filename from the Content-Disposition header
|
||||
parts := strings.Split(fileName, ";")
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
|
||||
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
|
||||
// Remove quotes if present
|
||||
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
|
||||
fileName = fileName[1 : len(fileName)-1]
|
||||
}
|
||||
// Guess MIME type based on file extension
|
||||
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.LocalFileData{
|
||||
Base64Data: base64Data,
|
||||
MimeType: resp.Header.Get("Content-Type"),
|
||||
MimeType: mimeType,
|
||||
Size: int64(len(fileBytes)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetMimeTypeByExtension(ext string) string {
|
||||
// Convert to lowercase for case-insensitive comparison
|
||||
ext = strings.ToLower(ext)
|
||||
switch ext {
|
||||
// Text files
|
||||
case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
|
||||
return "text/plain"
|
||||
|
||||
// Image files
|
||||
case "jpg", "jpeg":
|
||||
return "image/jpeg"
|
||||
case "png":
|
||||
return "image/png"
|
||||
case "gif":
|
||||
return "image/gif"
|
||||
|
||||
// Audio files
|
||||
case "mp3":
|
||||
return "audio/mp3"
|
||||
case "wav":
|
||||
return "audio/wav"
|
||||
case "mpeg":
|
||||
return "audio/mpeg"
|
||||
|
||||
// Video files
|
||||
case "mp4":
|
||||
return "video/mp4"
|
||||
case "wmv":
|
||||
return "video/wmv"
|
||||
case "flv":
|
||||
return "video/flv"
|
||||
case "mov":
|
||||
return "video/mov"
|
||||
case "mpg":
|
||||
return "video/mpg"
|
||||
case "avi":
|
||||
return "video/avi"
|
||||
case "mpegps":
|
||||
return "video/mpegps"
|
||||
|
||||
// Document files
|
||||
case "pdf":
|
||||
return "application/pdf"
|
||||
|
||||
default:
|
||||
return "application/octet-stream" // Default for unknown types
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
|
||||
cacheTokens int, cacheRatio float64, modelPrice float64) map[string]interface{} {
|
||||
cacheTokens int, cacheRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
other["model_ratio"] = modelRatio
|
||||
other["group_ratio"] = groupRatio
|
||||
@@ -16,6 +16,7 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
other["cache_tokens"] = cacheTokens
|
||||
other["cache_ratio"] = cacheRatio
|
||||
other["model_price"] = modelPrice
|
||||
other["user_group_ratio"] = userGroupRatio
|
||||
other["frt"] = float64(relayInfo.FirstResponseTime.UnixMilli() - relayInfo.StartTime.UnixMilli())
|
||||
if relayInfo.ReasoningEffort != "" {
|
||||
other["reasoning_effort"] = relayInfo.ReasoningEffort
|
||||
@@ -30,8 +31,8 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
return other
|
||||
}
|
||||
|
||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
|
||||
func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
|
||||
info["ws"] = true
|
||||
info["audio_input"] = usage.InputTokenDetails.AudioTokens
|
||||
info["audio_output"] = usage.OutputTokenDetails.AudioTokens
|
||||
@@ -42,8 +43,8 @@ func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, us
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice)
|
||||
func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice, userGroupRatio float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, 0, 0.0, modelPrice, userGroupRatio)
|
||||
info["audio"] = true
|
||||
info["audio_input"] = usage.PromptTokensDetails.AudioTokens
|
||||
info["audio_output"] = usage.CompletionTokenDetails.AudioTokens
|
||||
@@ -55,8 +56,8 @@ func GenerateAudioOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
|
||||
func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelRatio, groupRatio, completionRatio float64,
|
||||
cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice)
|
||||
cacheTokens int, cacheRatio float64, cacheCreationTokens int, cacheCreationRatio float64, modelPrice float64, userGroupRatio float64) map[string]interface{} {
|
||||
info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
|
||||
info["claude"] = true
|
||||
info["cache_creation_tokens"] = cacheCreationTokens
|
||||
info["cache_creation_ratio"] = cacheCreationRatio
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
@@ -10,7 +11,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -45,9 +46,9 @@ func calculateAudioQuota(info QuotaInfo) int {
|
||||
return int(quota.IntPart())
|
||||
}
|
||||
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
|
||||
|
||||
groupRatio := decimal.NewFromFloat(info.GroupRatio)
|
||||
modelRatio := decimal.NewFromFloat(info.ModelRatio)
|
||||
@@ -93,8 +94,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
modelRatio, _ := operation_setting.GetModelRatio(modelName)
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.Group)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(modelName)
|
||||
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
|
||||
log.Printf("final group ratio: %f", groupRatio)
|
||||
relayInfo.Group = autoGroup.(string)
|
||||
}
|
||||
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
@@ -108,7 +122,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
ModelName: modelName,
|
||||
UsePrice: relayInfo.UsePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: groupRatio,
|
||||
GroupRatio: actualGroupRatio,
|
||||
}
|
||||
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
@@ -130,8 +144,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
}
|
||||
|
||||
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||
modelPrice float64, usePrice bool, extraContent string) {
|
||||
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
@@ -141,9 +154,14 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
@@ -189,7 +207,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
@@ -205,9 +223,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := priceData.CompletionRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
|
||||
cacheRatio := priceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
@@ -256,7 +273,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice)
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
@@ -272,12 +289,12 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
@@ -333,7 +350,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
"image"
|
||||
"log"
|
||||
"math"
|
||||
@@ -11,78 +13,63 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
var o200kTokenEncoder *tiktoken.Tiktoken
|
||||
var defaultTokenEncoder tokenizer.Codec
|
||||
|
||||
// tokenEncoderMap is used to store token encoders for different models
|
||||
var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
|
||||
// tokenEncoderMutex protects tokenEncoderMap for concurrent access
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||
}
|
||||
defaultTokenEncoder = cl100TokenEncoder
|
||||
o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
|
||||
}
|
||||
for model, _ := range operation_setting.GetDefaultModelRatioMap() {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = cl100TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
if strings.HasPrefix(model, "gpt-4o") {
|
||||
tokenEncoderMap[model] = o200kTokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
}
|
||||
} else if strings.HasPrefix(model, "o") {
|
||||
tokenEncoderMap[model] = o200kTokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
}
|
||||
}
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
|
||||
return o200kTokenEncoder
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
// First, try to get the encoder from cache with read lock
|
||||
tokenEncoderMutex.RLock()
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
tokenEncoderMutex.RUnlock()
|
||||
return encoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
tokenEncoderMutex.RUnlock()
|
||||
|
||||
// If not in cache, create new encoder with write lock
|
||||
tokenEncoderMutex.Lock()
|
||||
defer tokenEncoderMutex.Unlock()
|
||||
|
||||
// Double-check if another goroutine already created the encoder
|
||||
if encoder, exists := tokenEncoderMap[model]; exists {
|
||||
return encoder
|
||||
}
|
||||
|
||||
// Create new encoder
|
||||
modelCodec, err := tokenizer.ForModel(tokenizer.Model(model))
|
||||
if err != nil {
|
||||
// Cache the default encoder for this model to avoid repeated failures
|
||||
tokenEncoderMap[model] = defaultTokenEncoder
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
// Cache the new encoder
|
||||
tokenEncoderMap[model] = modelCodec
|
||||
return modelCodec
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, ok := tokenEncoderMap[model]
|
||||
if ok && tokenEncoder != nil {
|
||||
return tokenEncoder
|
||||
}
|
||||
// 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型
|
||||
if ok {
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||
tokenEncoder = getModelDefaultTokenEncoder(model)
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
// 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder
|
||||
return getModelDefaultTokenEncoder(model)
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
tkm, _ := tokenEncoder.Count(text)
|
||||
return tkm
|
||||
}
|
||||
|
||||
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
package setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidateApiInfo 验证API信息格式
|
||||
func ValidateApiInfo(apiInfoStr string) error {
|
||||
if apiInfoStr == "" {
|
||||
return nil // 空字符串是合法的
|
||||
}
|
||||
|
||||
var apiInfoList []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(apiInfoStr), &apiInfoList); err != nil {
|
||||
return fmt.Errorf("API信息格式错误:%s", err.Error())
|
||||
}
|
||||
|
||||
// 验证数组长度
|
||||
if len(apiInfoList) > 50 {
|
||||
return fmt.Errorf("API信息数量不能超过50个")
|
||||
}
|
||||
|
||||
// 允许的颜色值
|
||||
validColors := map[string]bool{
|
||||
"blue": true, "green": true, "cyan": true, "purple": true, "pink": true,
|
||||
"red": true, "orange": true, "amber": true, "yellow": true, "lime": true,
|
||||
"light-green": true, "teal": true, "light-blue": true, "indigo": true,
|
||||
"violet": true, "grey": true,
|
||||
}
|
||||
|
||||
// URL正则表达式
|
||||
urlRegex := regexp.MustCompile(`^https?://[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*(/.*)?$`)
|
||||
|
||||
for i, apiInfo := range apiInfoList {
|
||||
// 检查必填字段
|
||||
urlStr, ok := apiInfo["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少URL字段", i+1)
|
||||
}
|
||||
|
||||
route, ok := apiInfo["route"].(string)
|
||||
if !ok || route == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1)
|
||||
}
|
||||
|
||||
description, ok := apiInfo["description"].(string)
|
||||
if !ok || description == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少说明字段", i+1)
|
||||
}
|
||||
|
||||
color, ok := apiInfo["color"].(string)
|
||||
if !ok || color == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少颜色字段", i+1)
|
||||
}
|
||||
|
||||
// 验证URL格式
|
||||
if !urlRegex.MatchString(urlStr) {
|
||||
return fmt.Errorf("第%d个API信息的URL格式不正确", i+1)
|
||||
}
|
||||
|
||||
// 验证URL可解析性
|
||||
if _, err := url.Parse(urlStr); err != nil {
|
||||
return fmt.Errorf("第%d个API信息的URL无法解析:%s", i+1, err.Error())
|
||||
}
|
||||
|
||||
// 验证字段长度
|
||||
if len(urlStr) > 500 {
|
||||
return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1)
|
||||
}
|
||||
|
||||
if len(route) > 100 {
|
||||
return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1)
|
||||
}
|
||||
|
||||
if len(description) > 200 {
|
||||
return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1)
|
||||
}
|
||||
|
||||
// 验证颜色值
|
||||
if !validColors[color] {
|
||||
return fmt.Errorf("第%d个API信息的颜色值不合法", i+1)
|
||||
}
|
||||
|
||||
// 检查并过滤危险字符(防止XSS)
|
||||
dangerousChars := []string{"<script", "<iframe", "javascript:", "onload=", "onerror=", "onclick="}
|
||||
for _, dangerous := range dangerousChars {
|
||||
if strings.Contains(strings.ToLower(description), dangerous) {
|
||||
return fmt.Errorf("第%d个API信息的说明包含不允许的内容", i+1)
|
||||
}
|
||||
if strings.Contains(strings.ToLower(route), dangerous) {
|
||||
return fmt.Errorf("第%d个API信息的线路描述包含不允许的内容", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetApiInfo 获取API信息列表
|
||||
func GetApiInfo() []map[string]interface{} {
|
||||
// 从OptionMap中获取API信息,如果不存在则返回空数组
|
||||
common.OptionMapRWMutex.RLock()
|
||||
apiInfoStr, exists := common.OptionMap["ApiInfo"]
|
||||
common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
if !exists || apiInfoStr == "" {
|
||||
// 如果没有配置,返回空数组
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
|
||||
// 解析存储的API信息
|
||||
var apiInfo []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(apiInfoStr), &apiInfo); err != nil {
|
||||
// 如果解析失败,返回空数组
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
|
||||
return apiInfo
|
||||
}
|
||||
31
setting/auto_group.go
Normal file
31
setting/auto_group.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package setting
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var AutoGroups = []string{
|
||||
"default",
|
||||
}
|
||||
|
||||
var DefaultUseAutoGroup = false
|
||||
|
||||
func ContainsAutoGroup(group string) bool {
|
||||
for _, autoGroup := range AutoGroups {
|
||||
if autoGroup == group {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func UpdateAutoGroupsByJsonString(jsonString string) error {
|
||||
AutoGroups = make([]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &AutoGroups)
|
||||
}
|
||||
|
||||
func AutoGroups2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(AutoGroups)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
39
setting/console_setting/config.go
Normal file
39
setting/console_setting/config.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package console_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type ConsoleSetting struct {
|
||||
ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
|
||||
UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
|
||||
Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
|
||||
FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
|
||||
ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
|
||||
UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
|
||||
AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
|
||||
FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultConsoleSetting = ConsoleSetting{
|
||||
ApiInfo: "",
|
||||
UptimeKumaGroups: "",
|
||||
Announcements: "",
|
||||
FAQ: "",
|
||||
ApiInfoEnabled: true,
|
||||
UptimeKumaEnabled: true,
|
||||
AnnouncementsEnabled: true,
|
||||
FAQEnabled: true,
|
||||
}
|
||||
|
||||
// 全局实例
|
||||
var consoleSetting = defaultConsoleSetting
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器,键名为 console_setting
|
||||
config.GlobalConfig.Register("console_setting", &consoleSetting)
|
||||
}
|
||||
|
||||
// GetConsoleSetting 获取 ConsoleSetting 配置实例
|
||||
func GetConsoleSetting() *ConsoleSetting {
|
||||
return &consoleSetting
|
||||
}
|
||||
304
setting/console_setting/validation.go
Normal file
304
setting/console_setting/validation.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package console_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"sort"
|
||||
)
|
||||
|
||||
var (
|
||||
urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
|
||||
dangerousChars = []string{"<script", "<iframe", "javascript:", "onload=", "onerror=", "onclick="}
|
||||
validColors = map[string]bool{
|
||||
"blue": true, "green": true, "cyan": true, "purple": true, "pink": true,
|
||||
"red": true, "orange": true, "amber": true, "yellow": true, "lime": true,
|
||||
"light-green": true, "teal": true, "light-blue": true, "indigo": true,
|
||||
"violet": true, "grey": true,
|
||||
}
|
||||
slugRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
)
|
||||
|
||||
func parseJSONArray(jsonStr string, typeName string) ([]map[string]interface{}, error) {
|
||||
var list []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &list); err != nil {
|
||||
return nil, fmt.Errorf("%s格式错误:%s", typeName, err.Error())
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func validateURL(urlStr string, index int, itemType string) error {
|
||||
if !urlRegex.MatchString(urlStr) {
|
||||
return fmt.Errorf("第%d个%s的URL格式不正确", index, itemType)
|
||||
}
|
||||
if _, err := url.Parse(urlStr); err != nil {
|
||||
return fmt.Errorf("第%d个%s的URL无法解析:%s", index, itemType, err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDangerousContent(content string, index int, itemType string) error {
|
||||
lower := strings.ToLower(content)
|
||||
for _, d := range dangerousChars {
|
||||
if strings.Contains(lower, d) {
|
||||
return fmt.Errorf("第%d个%s包含不允许的内容", index, itemType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getJSONList(jsonStr string) []map[string]interface{} {
|
||||
if jsonStr == "" {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
var list []map[string]interface{}
|
||||
json.Unmarshal([]byte(jsonStr), &list)
|
||||
return list
|
||||
}
|
||||
|
||||
func ValidateConsoleSettings(settingsStr string, settingType string) error {
|
||||
if settingsStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch settingType {
|
||||
case "ApiInfo":
|
||||
return validateApiInfo(settingsStr)
|
||||
case "Announcements":
|
||||
return validateAnnouncements(settingsStr)
|
||||
case "FAQ":
|
||||
return validateFAQ(settingsStr)
|
||||
case "UptimeKumaGroups":
|
||||
return validateUptimeKumaGroups(settingsStr)
|
||||
default:
|
||||
return fmt.Errorf("未知的设置类型:%s", settingType)
|
||||
}
|
||||
}
|
||||
|
||||
func validateApiInfo(apiInfoStr string) error {
|
||||
apiInfoList, err := parseJSONArray(apiInfoStr, "API信息")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(apiInfoList) > 50 {
|
||||
return fmt.Errorf("API信息数量不能超过50个")
|
||||
}
|
||||
|
||||
for i, apiInfo := range apiInfoList {
|
||||
urlStr, ok := apiInfo["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少URL字段", i+1)
|
||||
}
|
||||
route, ok := apiInfo["route"].(string)
|
||||
if !ok || route == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1)
|
||||
}
|
||||
description, ok := apiInfo["description"].(string)
|
||||
if !ok || description == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少说明字段", i+1)
|
||||
}
|
||||
color, ok := apiInfo["color"].(string)
|
||||
if !ok || color == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少颜色字段", i+1)
|
||||
}
|
||||
|
||||
if err := validateURL(urlStr, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(urlStr) > 500 {
|
||||
return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1)
|
||||
}
|
||||
if len(route) > 100 {
|
||||
return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1)
|
||||
}
|
||||
if len(description) > 200 {
|
||||
return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1)
|
||||
}
|
||||
|
||||
if !validColors[color] {
|
||||
return fmt.Errorf("第%d个API信息的颜色值不合法", i+1)
|
||||
}
|
||||
|
||||
if err := checkDangerousContent(description, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkDangerousContent(route, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetApiInfo() []map[string]interface{} {
|
||||
return getJSONList(GetConsoleSetting().ApiInfo)
|
||||
}
|
||||
|
||||
func validateAnnouncements(announcementsStr string) error {
|
||||
list, err := parseJSONArray(announcementsStr, "系统公告")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 100 {
|
||||
return fmt.Errorf("系统公告数量不能超过100个")
|
||||
}
|
||||
validTypes := map[string]bool{
|
||||
"default": true, "ongoing": true, "success": true, "warning": true, "error": true,
|
||||
}
|
||||
for i, ann := range list {
|
||||
content, ok := ann["content"].(string)
|
||||
if !ok || content == "" {
|
||||
return fmt.Errorf("第%d个公告缺少内容字段", i+1)
|
||||
}
|
||||
publishDateAny, exists := ann["publishDate"]
|
||||
if !exists {
|
||||
return fmt.Errorf("第%d个公告缺少发布日期字段", i+1)
|
||||
}
|
||||
publishDateStr, ok := publishDateAny.(string)
|
||||
if !ok || publishDateStr == "" {
|
||||
return fmt.Errorf("第%d个公告的发布日期不能为空", i+1)
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil {
|
||||
return fmt.Errorf("第%d个公告的发布日期格式错误", i+1)
|
||||
}
|
||||
if t, exists := ann["type"]; exists {
|
||||
if typeStr, ok := t.(string); ok {
|
||||
if !validTypes[typeStr] {
|
||||
return fmt.Errorf("第%d个公告的类型值不合法", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(content) > 500 {
|
||||
return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1)
|
||||
}
|
||||
if extra, exists := ann["extra"]; exists {
|
||||
if extraStr, ok := extra.(string); ok && len(extraStr) > 200 {
|
||||
return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateFAQ(faqStr string) error {
|
||||
list, err := parseJSONArray(faqStr, "FAQ信息")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 100 {
|
||||
return fmt.Errorf("FAQ数量不能超过100个")
|
||||
}
|
||||
for i, faq := range list {
|
||||
question, ok := faq["question"].(string)
|
||||
if !ok || question == "" {
|
||||
return fmt.Errorf("第%d个FAQ缺少问题字段", i+1)
|
||||
}
|
||||
answer, ok := faq["answer"].(string)
|
||||
if !ok || answer == "" {
|
||||
return fmt.Errorf("第%d个FAQ缺少答案字段", i+1)
|
||||
}
|
||||
if len(question) > 200 {
|
||||
return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1)
|
||||
}
|
||||
if len(answer) > 1000 {
|
||||
return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPublishTime(item map[string]interface{}) time.Time {
|
||||
if v, ok := item["publishDate"]; ok {
|
||||
if s, ok2 := v.(string); ok2 {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func GetAnnouncements() []map[string]interface{} {
|
||||
list := getJSONList(GetConsoleSetting().Announcements)
|
||||
sort.SliceStable(list, func(i, j int) bool {
|
||||
return getPublishTime(list[i]).After(getPublishTime(list[j]))
|
||||
})
|
||||
return list
|
||||
}
|
||||
|
||||
func GetFAQ() []map[string]interface{} {
|
||||
return getJSONList(GetConsoleSetting().FAQ)
|
||||
}
|
||||
|
||||
func validateUptimeKumaGroups(groupsStr string) error {
|
||||
groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(groups) > 20 {
|
||||
return fmt.Errorf("Uptime Kuma分组数量不能超过20个")
|
||||
}
|
||||
|
||||
nameSet := make(map[string]bool)
|
||||
|
||||
for i, group := range groups {
|
||||
categoryName, ok := group["categoryName"].(string)
|
||||
if !ok || categoryName == "" {
|
||||
return fmt.Errorf("第%d个分组缺少分类名称字段", i+1)
|
||||
}
|
||||
if nameSet[categoryName] {
|
||||
return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1)
|
||||
}
|
||||
nameSet[categoryName] = true
|
||||
urlStr, ok := group["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("第%d个分组缺少URL字段", i+1)
|
||||
}
|
||||
slug, ok := group["slug"].(string)
|
||||
if !ok || slug == "" {
|
||||
return fmt.Errorf("第%d个分组缺少Slug字段", i+1)
|
||||
}
|
||||
description, ok := group["description"].(string)
|
||||
if !ok {
|
||||
description = ""
|
||||
}
|
||||
|
||||
if err := validateURL(urlStr, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(categoryName) > 50 {
|
||||
return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1)
|
||||
}
|
||||
if len(urlStr) > 500 {
|
||||
return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1)
|
||||
}
|
||||
if len(slug) > 100 {
|
||||
return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1)
|
||||
}
|
||||
if len(description) > 200 {
|
||||
return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1)
|
||||
}
|
||||
|
||||
if !slugRegex.MatchString(slug) {
|
||||
return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1)
|
||||
}
|
||||
|
||||
if err := checkDangerousContent(description, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUptimeKumaGroups() []map[string]interface{} {
|
||||
return getJSONList(GetConsoleSetting().UptimeKumaGroups)
|
||||
}
|
||||
@@ -1,8 +1,45 @@
|
||||
package setting
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
var EpayId = ""
|
||||
var EpayKey = ""
|
||||
var Price = 7.3
|
||||
var MinTopUp = 1
|
||||
|
||||
var PayMethods = []map[string]string{
|
||||
{
|
||||
"name": "支付宝",
|
||||
"color": "rgba(var(--semi-blue-5), 1)",
|
||||
"type": "zfb",
|
||||
},
|
||||
{
|
||||
"name": "微信",
|
||||
"color": "rgba(var(--semi-green-5), 1)",
|
||||
"type": "wx",
|
||||
},
|
||||
}
|
||||
|
||||
func UpdatePayMethodsByJsonString(jsonString string) error {
|
||||
PayMethods = make([]map[string]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
}
|
||||
|
||||
func PayMethods2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(PayMethods)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func ContainsPayMethod(method string) bool {
|
||||
for _, payMethod := range PayMethods {
|
||||
if payMethod["type"] == method {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package operation_setting
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package setting
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -14,10 +14,19 @@ var groupRatio = map[string]float64{
|
||||
}
|
||||
var groupRatioMutex sync.RWMutex
|
||||
|
||||
var (
|
||||
GroupGroupRatio = map[string]map[string]float64{
|
||||
"vip": {
|
||||
"edit_this": 0.9,
|
||||
},
|
||||
}
|
||||
groupGroupRatioMutex sync.RWMutex
|
||||
)
|
||||
|
||||
func GetGroupRatioCopy() map[string]float64 {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
|
||||
groupRatioCopy := make(map[string]float64)
|
||||
for k, v := range groupRatio {
|
||||
groupRatioCopy[k] = v
|
||||
@@ -28,7 +37,7 @@ func GetGroupRatioCopy() map[string]float64 {
|
||||
func ContainsGroupRatio(name string) bool {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
|
||||
_, ok := groupRatio[name]
|
||||
return ok
|
||||
}
|
||||
@@ -36,7 +45,7 @@ func ContainsGroupRatio(name string) bool {
|
||||
func GroupRatio2JSONString() string {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
|
||||
jsonBytes, err := json.Marshal(groupRatio)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling model ratio: " + err.Error())
|
||||
@@ -47,7 +56,7 @@ func GroupRatio2JSONString() string {
|
||||
func UpdateGroupRatioByJSONString(jsonStr string) error {
|
||||
groupRatioMutex.Lock()
|
||||
defer groupRatioMutex.Unlock()
|
||||
|
||||
|
||||
groupRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &groupRatio)
|
||||
}
|
||||
@@ -55,7 +64,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
|
||||
func GetGroupRatio(name string) float64 {
|
||||
groupRatioMutex.RLock()
|
||||
defer groupRatioMutex.RUnlock()
|
||||
|
||||
|
||||
ratio, ok := groupRatio[name]
|
||||
if !ok {
|
||||
common.SysError("group ratio not found: " + name)
|
||||
@@ -64,6 +73,40 @@ func GetGroupRatio(name string) float64 {
|
||||
return ratio
|
||||
}
|
||||
|
||||
func GetGroupGroupRatio(group, name string) (float64, bool) {
|
||||
groupGroupRatioMutex.RLock()
|
||||
defer groupGroupRatioMutex.RUnlock()
|
||||
|
||||
gp, ok := GroupGroupRatio[group]
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
ratio, ok := gp[name]
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
return ratio, true
|
||||
}
|
||||
|
||||
func GroupGroupRatio2JSONString() string {
|
||||
groupGroupRatioMutex.RLock()
|
||||
defer groupGroupRatioMutex.RUnlock()
|
||||
|
||||
jsonBytes, err := json.Marshal(GroupGroupRatio)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling group-group ratio: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateGroupGroupRatioByJSONString(jsonStr string) error {
|
||||
groupGroupRatioMutex.Lock()
|
||||
defer groupGroupRatioMutex.Unlock()
|
||||
|
||||
GroupGroupRatio = make(map[string]map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &GroupGroupRatio)
|
||||
}
|
||||
|
||||
func CheckGroupRatio(jsonStr string) error {
|
||||
checkGroupRatio := make(map[string]float64)
|
||||
err := json.Unmarshal([]byte(jsonStr), &checkGroupRatio)
|
||||
@@ -1,8 +1,9 @@
|
||||
package operation_setting
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@@ -142,6 +143,11 @@ var defaultModelRatio = map[string]float64{
|
||||
"gemini-2.5-flash-preview-04-17": 0.075,
|
||||
"gemini-2.5-flash-preview-04-17-thinking": 0.075,
|
||||
"gemini-2.5-flash-preview-04-17-nothinking": 0.075,
|
||||
"gemini-2.5-flash-preview-05-20": 0.075,
|
||||
"gemini-2.5-flash-preview-05-20-thinking": 0.075,
|
||||
"gemini-2.5-flash-preview-05-20-nothinking": 0.075,
|
||||
"gemini-2.5-flash-thinking-*": 0.075, // 用于为后续所有2.5 flash thinking budget 模型设置默认倍率
|
||||
"gemini-2.5-pro-thinking-*": 0.625, // 用于为后续所有2.5 pro thinking budget 模型设置默认倍率
|
||||
"text-embedding-004": 0.001,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
@@ -342,16 +348,26 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
||||
}
|
||||
|
||||
// 处理带有思考预算的模型名称,方便统一定价
|
||||
func handleThinkingBudgetModel(name, prefix, wildcard string) string {
|
||||
if strings.HasPrefix(name, prefix) && strings.Contains(name, "-thinking-") {
|
||||
return wildcard
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func GetModelRatio(name string) (float64, bool) {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
|
||||
name = handleThinkingBudgetModel(name, "gemini-2.5-flash", "gemini-2.5-flash-thinking-*")
|
||||
name = handleThinkingBudgetModel(name, "gemini-2.5-pro", "gemini-2.5-pro-thinking-*")
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
ratio, ok := modelRatioMap[name]
|
||||
if !ok {
|
||||
return 37.5, SelfUseModeEnabled
|
||||
return 37.5, operation_setting.SelfUseModeEnabled
|
||||
}
|
||||
return ratio, true
|
||||
}
|
||||
@@ -470,9 +486,9 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
return 4, true
|
||||
} else if strings.HasPrefix(name, "gemini-2.0") {
|
||||
return 4, true
|
||||
} else if strings.HasPrefix(name, "gemini-2.5-pro-preview") {
|
||||
} else if strings.HasPrefix(name, "gemini-2.5-pro") { // 移除preview来增加兼容性,这里假设正式版的倍率和preview一致
|
||||
return 8, true
|
||||
} else if strings.HasPrefix(name, "gemini-2.5-flash-preview") {
|
||||
} else if strings.HasPrefix(name, "gemini-2.5-flash") { // 同上
|
||||
if strings.HasSuffix(name, "-nothinking") {
|
||||
return 4, false
|
||||
} else {
|
||||
@@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool {
|
||||
_, ok := userUsableGroups[groupName]
|
||||
return ok
|
||||
}
|
||||
|
||||
func GetUsableGroupDescription(groupName string) string {
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
return desc
|
||||
}
|
||||
return groupName
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
# React Template
|
||||
|
||||
## Basic Usages
|
||||
|
||||
```shell
|
||||
# Runs the app in the development mode
|
||||
npm start
|
||||
|
||||
# Builds the app for production to the `build` folder
|
||||
npm run build
|
||||
```
|
||||
|
||||
If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build,
|
||||
for example: `REACT_APP_SERVER=http://your.domain.com`.
|
||||
|
||||
Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled.
|
||||
|
||||
## Reference
|
||||
|
||||
1. https://github.com/OIerDb-ng/OIerDb
|
||||
2. https://github.com/cornflourblue/react-hooks-redux-registration-login-example
|
||||
@@ -37,8 +37,6 @@
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"semantic-ui-offline": "^2.5.0",
|
||||
"semantic-ui-react": "^2.1.3",
|
||||
"sse.js": "^2.6.0",
|
||||
"unist-util-visit": "^5.0.0",
|
||||
"use-debounce": "^10.0.4"
|
||||
|
||||
5584
web/pnpm-lock.yaml
generated
5584
web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -363,7 +363,7 @@ const HeaderBar = () => {
|
||||
onClose={() => setNoticeVisible(false)}
|
||||
isMobile={styleState.isMobile}
|
||||
/>
|
||||
<div className="w-full px-4">
|
||||
<div className="w-full px-2">
|
||||
<div className="flex items-center justify-between h-16">
|
||||
<div className="flex items-center">
|
||||
<div className="md:hidden">
|
||||
|
||||
@@ -64,11 +64,7 @@ const NoticeModal = ({ visible, onClose, isMobile }) => {
|
||||
return (
|
||||
<div
|
||||
dangerouslySetInnerHTML={{ __html: noticeContent }}
|
||||
className="max-h-[60vh] overflow-y-auto pr-2"
|
||||
style={{
|
||||
scrollbarWidth: 'thin',
|
||||
scrollbarColor: 'var(--semi-color-tertiary) transparent'
|
||||
}}
|
||||
className="notice-content-scroll max-h-[60vh] overflow-y-auto pr-2"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -134,11 +134,10 @@ const PageLayout = () => {
|
||||
<Content
|
||||
style={{
|
||||
flex: '1 0 auto',
|
||||
overflowY: styleState.isMobile ? 'visible' : 'auto',
|
||||
overflowY: styleState.isMobile ? 'visible' : 'hidden',
|
||||
WebkitOverflowScrolling: 'touch',
|
||||
padding: shouldInnerPadding ? '24px' : '0',
|
||||
padding: shouldInnerPadding ? (styleState.isMobile ? '5px' : '24px') : '0',
|
||||
position: 'relative',
|
||||
marginTop: styleState.isMobile ? '2px' : '0',
|
||||
}}
|
||||
>
|
||||
<App />
|
||||
|
||||
@@ -1,14 +1,32 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Card, Spin } from '@douyinfe/semi-ui';
|
||||
import { API, showError } from '../../helpers';
|
||||
import React, { useEffect, useState, useMemo } from 'react';
|
||||
import { Card, Spin, Button, Modal } from '@douyinfe/semi-ui';
|
||||
import { API, showError, showSuccess } from '../../helpers';
|
||||
import SettingsAPIInfo from '../../pages/Setting/Dashboard/SettingsAPIInfo.js';
|
||||
import SettingsAnnouncements from '../../pages/Setting/Dashboard/SettingsAnnouncements.js';
|
||||
import SettingsFAQ from '../../pages/Setting/Dashboard/SettingsFAQ.js';
|
||||
import SettingsUptimeKuma from '../../pages/Setting/Dashboard/SettingsUptimeKuma.js';
|
||||
|
||||
const DashboardSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
'console_setting.api_info': '',
|
||||
'console_setting.announcements': '',
|
||||
'console_setting.faq': '',
|
||||
'console_setting.uptime_kuma_groups': '',
|
||||
'console_setting.api_info_enabled': '',
|
||||
'console_setting.announcements_enabled': '',
|
||||
'console_setting.faq_enabled': '',
|
||||
'console_setting.uptime_kuma_enabled': '',
|
||||
|
||||
// 用于迁移检测的旧键,下个版本会删除
|
||||
ApiInfo: '',
|
||||
Announcements: '',
|
||||
FAQ: '',
|
||||
UptimeKumaUrl: '',
|
||||
UptimeKumaSlug: '',
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
const [showMigrateModal, setShowMigrateModal] = useState(false); // 下个版本会删除
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
@@ -42,13 +60,71 @@ const DashboardSetting = () => {
|
||||
onRefresh();
|
||||
}, []);
|
||||
|
||||
// 用于迁移检测的旧键,下个版本会删除
|
||||
const hasLegacyData = useMemo(() => {
|
||||
const legacyKeys = ['ApiInfo', 'Announcements', 'FAQ', 'UptimeKumaUrl', 'UptimeKumaSlug'];
|
||||
return legacyKeys.some(k => inputs[k]);
|
||||
}, [inputs]);
|
||||
|
||||
useEffect(() => {
|
||||
if (hasLegacyData) {
|
||||
setShowMigrateModal(true);
|
||||
}
|
||||
}, [hasLegacyData]);
|
||||
|
||||
const handleMigrate = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
await API.post('/api/option/migrate_console_setting');
|
||||
showSuccess('旧配置迁移完成');
|
||||
await onRefresh();
|
||||
setShowMigrateModal(false);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
showError('迁移失败: ' + (err.message || '未知错误'));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* 用于迁移检测的旧键模态框,下个版本会删除 */}
|
||||
<Modal
|
||||
title="配置迁移确认"
|
||||
visible={showMigrateModal}
|
||||
onOk={handleMigrate}
|
||||
onCancel={() => setShowMigrateModal(false)}
|
||||
confirmLoading={loading}
|
||||
okText="确认迁移"
|
||||
cancelText="取消"
|
||||
>
|
||||
<p>检测到旧版本的配置数据,是否要迁移到新的配置格式?</p>
|
||||
<p style={{ color: '#f57c00', marginTop: '10px' }}>
|
||||
<strong>注意:</strong>迁移过程中会自动处理数据格式转换,迁移完成后旧配置将被清除,请在迁移前在数据库中备份好旧配置。
|
||||
</p>
|
||||
</Modal>
|
||||
|
||||
{/* API信息管理 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsAPIInfo options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
|
||||
{/* 系统公告管理 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsAnnouncements options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
|
||||
{/* 常见问答管理 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsFAQ options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
|
||||
{/* Uptime Kuma 监控设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsUptimeKuma options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
</Spin>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
|
||||
import { Card, Spin } from '@douyinfe/semi-ui';
|
||||
import SettingsGeneral from '../../pages/Setting/Operation/SettingsGeneral.js';
|
||||
import SettingsDrawing from '../../pages/Setting/Operation/SettingsDrawing.js';
|
||||
import SettingsSensitiveWords from '../../pages/Setting/Operation/SettingsSensitiveWords.js';
|
||||
@@ -7,60 +7,58 @@ import SettingsLog from '../../pages/Setting/Operation/SettingsLog.js';
|
||||
import SettingsDataDashboard from '../../pages/Setting/Operation/SettingsDataDashboard.js';
|
||||
import SettingsMonitoring from '../../pages/Setting/Operation/SettingsMonitoring.js';
|
||||
import SettingsCreditLimit from '../../pages/Setting/Operation/SettingsCreditLimit.js';
|
||||
import ModelSettingsVisualEditor from '../../pages/Setting/Operation/ModelSettingsVisualEditor.js';
|
||||
import GroupRatioSettings from '../../pages/Setting/Operation/GroupRatioSettings.js';
|
||||
import ModelRatioSettings from '../../pages/Setting/Operation/ModelRatioSettings.js';
|
||||
|
||||
import { API, showError, showSuccess } from '../../helpers';
|
||||
import SettingsChats from '../../pages/Setting/Operation/SettingsChats.js';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ModelRatioNotSetEditor from '../../pages/Setting/Operation/ModelRationNotSetEditor.js';
|
||||
import { API, showError } from '../../helpers';
|
||||
|
||||
const OperationSetting = () => {
|
||||
const { t } = useTranslation();
|
||||
let [inputs, setInputs] = useState({
|
||||
/* 额度相关 */
|
||||
QuotaForNewUser: 0,
|
||||
PreConsumedQuota: 0,
|
||||
QuotaForInviter: 0,
|
||||
QuotaForInvitee: 0,
|
||||
QuotaRemindThreshold: 0,
|
||||
PreConsumedQuota: 0,
|
||||
StreamCacheQueueLength: 0,
|
||||
ModelRatio: '',
|
||||
CacheRatio: '',
|
||||
CompletionRatio: '',
|
||||
ModelPrice: '',
|
||||
GroupRatio: '',
|
||||
UserUsableGroups: '',
|
||||
|
||||
/* 通用设置 */
|
||||
TopUpLink: '',
|
||||
'general_setting.docs_link': '',
|
||||
// ChatLink2: '', // 添加的新状态变量
|
||||
QuotaPerUnit: 0,
|
||||
AutomaticDisableChannelEnabled: false,
|
||||
AutomaticEnableChannelEnabled: false,
|
||||
ChannelDisableThreshold: 0,
|
||||
LogConsumeEnabled: false,
|
||||
RetryTimes: 0,
|
||||
DisplayInCurrencyEnabled: false,
|
||||
DisplayTokenStatEnabled: false,
|
||||
CheckSensitiveEnabled: false,
|
||||
CheckSensitiveOnPromptEnabled: false,
|
||||
CheckSensitiveOnCompletionEnabled: '',
|
||||
StopOnSensitiveEnabled: '',
|
||||
SensitiveWords: '',
|
||||
DefaultCollapseSidebar: false,
|
||||
DemoSiteEnabled: false,
|
||||
SelfUseModeEnabled: false,
|
||||
|
||||
/* 绘图设置 */
|
||||
DrawingEnabled: false,
|
||||
MjNotifyEnabled: false,
|
||||
MjAccountFilterEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjForwardUrlEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjActionCheckSuccessEnabled: false,
|
||||
DrawingEnabled: false,
|
||||
|
||||
/* 敏感词设置 */
|
||||
CheckSensitiveEnabled: false,
|
||||
CheckSensitiveOnPromptEnabled: false,
|
||||
SensitiveWords: '',
|
||||
|
||||
/* 日志设置 */
|
||||
LogConsumeEnabled: false,
|
||||
|
||||
/* 数据看板 */
|
||||
DataExportEnabled: false,
|
||||
DataExportDefaultTime: 'hour',
|
||||
DataExportInterval: 5,
|
||||
DefaultCollapseSidebar: false, // 默认折叠侧边栏
|
||||
RetryTimes: 0,
|
||||
Chats: '[]',
|
||||
DemoSiteEnabled: false,
|
||||
SelfUseModeEnabled: false,
|
||||
|
||||
/* 监控设置 */
|
||||
ChannelDisableThreshold: 0,
|
||||
QuotaRemindThreshold: 0,
|
||||
AutomaticDisableChannelEnabled: false,
|
||||
AutomaticEnableChannelEnabled: false,
|
||||
AutomaticDisableKeywords: '',
|
||||
|
||||
/* 聊天设置 */
|
||||
Chats: '[]',
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
@@ -71,16 +69,6 @@ const OperationSetting = () => {
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (
|
||||
item.key === 'ModelRatio' ||
|
||||
item.key === 'GroupRatio' ||
|
||||
item.key === 'UserUsableGroups' ||
|
||||
item.key === 'CompletionRatio' ||
|
||||
item.key === 'ModelPrice' ||
|
||||
item.key === 'CacheRatio'
|
||||
) {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
}
|
||||
if (
|
||||
item.key.endsWith('Enabled') ||
|
||||
['DefaultCollapseSidebar'].includes(item.key)
|
||||
@@ -147,24 +135,6 @@ const OperationSetting = () => {
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsChats options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 分组倍率设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<GroupRatioSettings options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 合并模型倍率设置和可视化倍率设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<Tabs type='line'>
|
||||
<Tabs.TabPane tab={t('模型倍率设置')} itemKey='model'>
|
||||
<ModelRatioSettings options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('可视化倍率设置')} itemKey='visual'>
|
||||
<ModelSettingsVisualEditor options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('未设置倍率模型')} itemKey='unset_models'>
|
||||
<ModelRatioNotSetEditor options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
</Tabs>
|
||||
</Card>
|
||||
</Spin>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -103,6 +103,7 @@ const PersonalSetting = () => {
|
||||
webhookSecret: '',
|
||||
notificationEmail: '',
|
||||
acceptUnsetModelRatioModel: false,
|
||||
recordIpLog: false,
|
||||
});
|
||||
const [modelsLoading, setModelsLoading] = useState(true);
|
||||
const [showWebhookDocs, setShowWebhookDocs] = useState(true);
|
||||
@@ -147,6 +148,7 @@ const PersonalSetting = () => {
|
||||
notificationEmail: settings.notification_email || '',
|
||||
acceptUnsetModelRatioModel:
|
||||
settings.accept_unset_model_ratio_model || false,
|
||||
recordIpLog: settings.record_ip_log || false,
|
||||
});
|
||||
}
|
||||
}, [userState?.user?.setting]);
|
||||
@@ -346,7 +348,7 @@ const PersonalSetting = () => {
|
||||
const handleNotificationSettingChange = (type, value) => {
|
||||
setNotificationSettings((prev) => ({
|
||||
...prev,
|
||||
[type]: value.target ? value.target.value : value, // 处理 Radio 事件对象
|
||||
[type]: value.target ? value.target.value !== undefined ? value.target.value : value.target.checked : value, // handle checkbox properly
|
||||
}));
|
||||
};
|
||||
|
||||
@@ -362,16 +364,17 @@ const PersonalSetting = () => {
|
||||
notification_email: notificationSettings.notificationEmail,
|
||||
accept_unset_model_ratio_model:
|
||||
notificationSettings.acceptUnsetModelRatioModel,
|
||||
record_ip_log: notificationSettings.recordIpLog,
|
||||
});
|
||||
|
||||
if (res.data.success) {
|
||||
showSuccess(t('通知设置已更新'));
|
||||
showSuccess(t('设置保存成功'));
|
||||
await getUserData();
|
||||
} else {
|
||||
showError(res.data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showError(t('更新通知设置失败'));
|
||||
showError(t('设置保存失败'));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1063,7 +1066,7 @@ const PersonalSetting = () => {
|
||||
tab={
|
||||
<div className="flex items-center">
|
||||
<Bell size={16} className="mr-2" />
|
||||
{t('通知设置')}
|
||||
{t('其他设置')}
|
||||
</div>
|
||||
}
|
||||
itemKey='notification'
|
||||
@@ -1228,28 +1231,68 @@ const PersonalSetting = () => {
|
||||
<TabPane
|
||||
tab={t('价格设置')}
|
||||
itemKey='price'
|
||||
>
|
||||
<div className="py-4">
|
||||
<div className="space-y-4">
|
||||
{/* 接受未设置价格模型 */}
|
||||
<div className="bg-white rounded-xl">
|
||||
<div className="flex items-start">
|
||||
<div className="w-10 h-10 rounded-full bg-slate-100 flex items-center justify-center mt-1">
|
||||
<Shield size={20} className="text-slate-600" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<Typography.Text strong className="block mb-2">
|
||||
{t('接受未设置价格模型')}
|
||||
</Typography.Text>
|
||||
<div className="text-gray-500 text-sm">
|
||||
{t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
|
||||
</div>
|
||||
</div>
|
||||
<Checkbox
|
||||
checked={notificationSettings.acceptUnsetModelRatioModel}
|
||||
onChange={(e) =>
|
||||
handleNotificationSettingChange(
|
||||
'acceptUnsetModelRatioModel',
|
||||
e.target.checked,
|
||||
)
|
||||
}
|
||||
className="ml-4"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</TabPane>
|
||||
|
||||
<TabPane
|
||||
tab={t('IP记录')}
|
||||
itemKey='ip'
|
||||
>
|
||||
<div className="py-4">
|
||||
<div className="bg-white rounded-xl">
|
||||
<div className="flex items-start">
|
||||
<div className="w-10 h-10 rounded-full bg-slate-100 flex items-center justify-center mt-1">
|
||||
<Shield size={20} className="text-slate-600" />
|
||||
<ShieldCheck size={20} className="text-slate-600" />
|
||||
</div>
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<Typography.Text strong className="block mb-2">
|
||||
{t('接受未设置价格模型')}
|
||||
{t('记录请求与错误日志 IP')}
|
||||
</Typography.Text>
|
||||
<div className="text-gray-500 text-sm">
|
||||
{t('当模型没有设置价格时仍接受调用,仅当您信任该网站时使用,可能会产生高额费用')}
|
||||
{t('开启后,仅“消费”和“错误”日志将记录您的客户端 IP 地址')}
|
||||
</div>
|
||||
</div>
|
||||
<Checkbox
|
||||
checked={notificationSettings.acceptUnsetModelRatioModel}
|
||||
checked={notificationSettings.recordIpLog}
|
||||
onChange={(e) =>
|
||||
handleNotificationSettingChange(
|
||||
'acceptUnsetModelRatioModel',
|
||||
'recordIpLog',
|
||||
e.target.checked,
|
||||
)
|
||||
}
|
||||
|
||||
109
web/src/components/settings/RatioSetting.js
Normal file
109
web/src/components/settings/RatioSetting.js
Normal file
@@ -0,0 +1,109 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import GroupRatioSettings from '../../pages/Setting/Ratio/GroupRatioSettings.js';
|
||||
import ModelRatioSettings from '../../pages/Setting/Ratio/ModelRatioSettings.js';
|
||||
import ModelSettingsVisualEditor from '../../pages/Setting/Ratio/ModelSettingsVisualEditor.js';
|
||||
import ModelRatioNotSetEditor from '../../pages/Setting/Ratio/ModelRationNotSetEditor.js';
|
||||
|
||||
import { API, showError } from '../../helpers';
|
||||
|
||||
const RatioSetting = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
let [inputs, setInputs] = useState({
|
||||
ModelPrice: '',
|
||||
ModelRatio: '',
|
||||
CacheRatio: '',
|
||||
CompletionRatio: '',
|
||||
GroupRatio: '',
|
||||
GroupGroupRatio: '',
|
||||
AutoGroups: '',
|
||||
DefaultUseAutoGroup: false,
|
||||
UserUsableGroups: '',
|
||||
});
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (
|
||||
item.key === 'ModelRatio' ||
|
||||
item.key === 'GroupRatio' ||
|
||||
item.key === 'GroupGroupRatio' ||
|
||||
item.key === 'AutoGroups' ||
|
||||
item.key === 'UserUsableGroups' ||
|
||||
item.key === 'CompletionRatio' ||
|
||||
item.key === 'ModelPrice' ||
|
||||
item.key === 'CacheRatio'
|
||||
) {
|
||||
try {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
} catch (e) {
|
||||
// 如果后端返回的不是合法 JSON,直接展示
|
||||
}
|
||||
}
|
||||
if (['DefaultUseAutoGroup'].includes(item.key)) {
|
||||
newInputs[item.key] = item.value === 'true' ? true : false;
|
||||
} else {
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
});
|
||||
setInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const onRefresh = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
await getOptions();
|
||||
} catch (error) {
|
||||
showError('刷新失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
onRefresh();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* 分组倍率设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<GroupRatioSettings options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 模型倍率设置以及可视化编辑器 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<Tabs type='line'>
|
||||
<Tabs.TabPane tab={t('模型倍率设置')} itemKey='model'>
|
||||
<ModelRatioSettings options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('可视化倍率设置')} itemKey='visual'>
|
||||
<ModelSettingsVisualEditor
|
||||
options={inputs}
|
||||
refresh={onRefresh}
|
||||
/>
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('未设置倍率模型')} itemKey='unset_models'>
|
||||
<ModelRatioNotSetEditor
|
||||
options={inputs}
|
||||
refresh={onRefresh}
|
||||
/>
|
||||
</Tabs.TabPane>
|
||||
</Tabs>
|
||||
</Card>
|
||||
</Spin>
|
||||
);
|
||||
};
|
||||
|
||||
export default RatioSetting;
|
||||
@@ -17,7 +17,7 @@ import {
|
||||
removeTrailingSlash,
|
||||
showError,
|
||||
showSuccess,
|
||||
verifyJSON
|
||||
verifyJSON,
|
||||
} from '../../helpers';
|
||||
import axios from 'axios';
|
||||
|
||||
@@ -73,6 +73,7 @@ const SystemSetting = () => {
|
||||
LinuxDOOAuthEnabled: '',
|
||||
LinuxDOClientId: '',
|
||||
LinuxDOClientSecret: '',
|
||||
PayMethods: '',
|
||||
});
|
||||
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
@@ -230,6 +231,12 @@ const SystemSetting = () => {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (originInputs['PayMethods'] !== inputs.PayMethods) {
|
||||
if (!verifyJSON(inputs.PayMethods)) {
|
||||
showError('充值方式设置不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const options = [
|
||||
{ key: 'PayAddress', value: removeTrailingSlash(inputs.PayAddress) },
|
||||
@@ -256,6 +263,9 @@ const SystemSetting = () => {
|
||||
if (originInputs['TopupGroupRatio'] !== inputs.TopupGroupRatio) {
|
||||
options.push({ key: 'TopupGroupRatio', value: inputs.TopupGroupRatio });
|
||||
}
|
||||
if (originInputs['PayMethods'] !== inputs.PayMethods) {
|
||||
options.push({ key: 'PayMethods', value: inputs.PayMethods });
|
||||
}
|
||||
|
||||
await updateOptions(options);
|
||||
};
|
||||
@@ -658,6 +668,12 @@ const SystemSetting = () => {
|
||||
placeholder='为一个 JSON 文本,键为组名称,值为倍率'
|
||||
autosize
|
||||
/>
|
||||
<Form.TextArea
|
||||
field='PayMethods'
|
||||
label='充值方式设置'
|
||||
placeholder='为一个 JSON 文本'
|
||||
autosize
|
||||
/>
|
||||
<Button onClick={submitPayAddress}>更新支付设置</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import React, { useEffect, useState, useMemo, useRef } from 'react';
|
||||
import {
|
||||
API,
|
||||
showError,
|
||||
@@ -16,11 +16,6 @@ import {
|
||||
XCircle,
|
||||
AlertCircle,
|
||||
HelpCircle,
|
||||
TestTube,
|
||||
Zap,
|
||||
Timer,
|
||||
Clock,
|
||||
AlertTriangle,
|
||||
Coins,
|
||||
Tags
|
||||
} from 'lucide-react';
|
||||
@@ -41,8 +36,11 @@ import {
|
||||
Tag,
|
||||
Tooltip,
|
||||
Typography,
|
||||
Checkbox,
|
||||
Card,
|
||||
Form
|
||||
Form,
|
||||
Tabs,
|
||||
TabPane
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
@@ -51,7 +49,6 @@ import {
|
||||
import EditChannel from '../../pages/Channel/EditChannel.js';
|
||||
import {
|
||||
IconTreeTriangleDown,
|
||||
IconFilter,
|
||||
IconPlus,
|
||||
IconRefresh,
|
||||
IconSetting,
|
||||
@@ -141,48 +138,139 @@ const ChannelsTable = () => {
|
||||
time = time.toFixed(2) + t(' 秒');
|
||||
if (responseTime === 0) {
|
||||
return (
|
||||
<Tag size='large' color='grey' shape='circle' prefixIcon={<TestTube size={14} />}>
|
||||
<Tag size='large' color='grey' shape='circle'>
|
||||
{t('未测试')}
|
||||
</Tag>
|
||||
);
|
||||
} else if (responseTime <= 1000) {
|
||||
return (
|
||||
<Tag size='large' color='green' shape='circle' prefixIcon={<Zap size={14} />}>
|
||||
<Tag size='large' color='green' shape='circle'>
|
||||
{time}
|
||||
</Tag>
|
||||
);
|
||||
} else if (responseTime <= 3000) {
|
||||
return (
|
||||
<Tag size='large' color='lime' shape='circle' prefixIcon={<Timer size={14} />}>
|
||||
<Tag size='large' color='lime' shape='circle'>
|
||||
{time}
|
||||
</Tag>
|
||||
);
|
||||
} else if (responseTime <= 5000) {
|
||||
return (
|
||||
<Tag size='large' color='yellow' shape='circle' prefixIcon={<Clock size={14} />}>
|
||||
<Tag size='large' color='yellow' shape='circle'>
|
||||
{time}
|
||||
</Tag>
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<Tag size='large' color='red' shape='circle' prefixIcon={<AlertTriangle size={14} />}>
|
||||
<Tag size='large' color='red' shape='circle'>
|
||||
{time}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Define all columns
|
||||
const columns = [
|
||||
// Define column keys for selection
|
||||
const COLUMN_KEYS = {
|
||||
ID: 'id',
|
||||
NAME: 'name',
|
||||
GROUP: 'group',
|
||||
TYPE: 'type',
|
||||
STATUS: 'status',
|
||||
RESPONSE_TIME: 'response_time',
|
||||
BALANCE: 'balance',
|
||||
PRIORITY: 'priority',
|
||||
WEIGHT: 'weight',
|
||||
OPERATE: 'operate',
|
||||
};
|
||||
|
||||
// State for column visibility
|
||||
const [visibleColumns, setVisibleColumns] = useState({});
|
||||
const [showColumnSelector, setShowColumnSelector] = useState(false);
|
||||
|
||||
// Load saved column preferences from localStorage
|
||||
useEffect(() => {
|
||||
const savedColumns = localStorage.getItem('channels-table-columns');
|
||||
if (savedColumns) {
|
||||
try {
|
||||
const parsed = JSON.parse(savedColumns);
|
||||
// Make sure all columns are accounted for
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
const merged = { ...defaults, ...parsed };
|
||||
setVisibleColumns(merged);
|
||||
} catch (e) {
|
||||
console.error('Failed to parse saved column preferences', e);
|
||||
initDefaultColumns();
|
||||
}
|
||||
} else {
|
||||
initDefaultColumns();
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Update table when column visibility changes
|
||||
useEffect(() => {
|
||||
if (Object.keys(visibleColumns).length > 0) {
|
||||
// Save to localStorage
|
||||
localStorage.setItem(
|
||||
'channels-table-columns',
|
||||
JSON.stringify(visibleColumns),
|
||||
);
|
||||
}
|
||||
}, [visibleColumns]);
|
||||
|
||||
// Get default column visibility
|
||||
const getDefaultColumnVisibility = () => {
|
||||
return {
|
||||
[COLUMN_KEYS.ID]: true,
|
||||
[COLUMN_KEYS.NAME]: true,
|
||||
[COLUMN_KEYS.GROUP]: true,
|
||||
[COLUMN_KEYS.TYPE]: true,
|
||||
[COLUMN_KEYS.STATUS]: true,
|
||||
[COLUMN_KEYS.RESPONSE_TIME]: true,
|
||||
[COLUMN_KEYS.BALANCE]: true,
|
||||
[COLUMN_KEYS.PRIORITY]: true,
|
||||
[COLUMN_KEYS.WEIGHT]: true,
|
||||
[COLUMN_KEYS.OPERATE]: true,
|
||||
};
|
||||
};
|
||||
|
||||
// Initialize default column visibility
|
||||
const initDefaultColumns = () => {
|
||||
const defaults = getDefaultColumnVisibility();
|
||||
setVisibleColumns(defaults);
|
||||
};
|
||||
|
||||
// Handle column visibility change
|
||||
const handleColumnVisibilityChange = (columnKey, checked) => {
|
||||
const updatedColumns = { ...visibleColumns, [columnKey]: checked };
|
||||
setVisibleColumns(updatedColumns);
|
||||
};
|
||||
|
||||
// Handle "Select All" checkbox
|
||||
const handleSelectAll = (checked) => {
|
||||
const allKeys = Object.keys(COLUMN_KEYS).map((key) => COLUMN_KEYS[key]);
|
||||
const updatedColumns = {};
|
||||
|
||||
allKeys.forEach((key) => {
|
||||
updatedColumns[key] = checked;
|
||||
});
|
||||
|
||||
setVisibleColumns(updatedColumns);
|
||||
};
|
||||
|
||||
// Define all columns with keys
|
||||
const allColumns = [
|
||||
{
|
||||
key: COLUMN_KEYS.ID,
|
||||
title: t('ID'),
|
||||
dataIndex: 'id',
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.NAME,
|
||||
title: t('名称'),
|
||||
dataIndex: 'name',
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.GROUP,
|
||||
title: t('分组'),
|
||||
dataIndex: 'group',
|
||||
render: (text, record, index) => (
|
||||
@@ -201,6 +289,7 @@ const ChannelsTable = () => {
|
||||
),
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.TYPE,
|
||||
title: t('类型'),
|
||||
dataIndex: 'type',
|
||||
render: (text, record, index) => {
|
||||
@@ -212,6 +301,7 @@ const ChannelsTable = () => {
|
||||
},
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.STATUS,
|
||||
title: t('状态'),
|
||||
dataIndex: 'status',
|
||||
render: (text, record, index) => {
|
||||
@@ -237,6 +327,7 @@ const ChannelsTable = () => {
|
||||
},
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.RESPONSE_TIME,
|
||||
title: t('响应时间'),
|
||||
dataIndex: 'response_time',
|
||||
render: (text, record, index) => (
|
||||
@@ -244,6 +335,7 @@ const ChannelsTable = () => {
|
||||
),
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.BALANCE,
|
||||
title: t('已用/剩余'),
|
||||
dataIndex: 'expired_time',
|
||||
render: (text, record, index) => {
|
||||
@@ -283,6 +375,7 @@ const ChannelsTable = () => {
|
||||
},
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.PRIORITY,
|
||||
title: t('优先级'),
|
||||
dataIndex: 'priority',
|
||||
render: (text, record, index) => {
|
||||
@@ -334,6 +427,7 @@ const ChannelsTable = () => {
|
||||
},
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.WEIGHT,
|
||||
title: t('权重'),
|
||||
dataIndex: 'weight',
|
||||
render: (text, record, index) => {
|
||||
@@ -385,6 +479,7 @@ const ChannelsTable = () => {
|
||||
},
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.OPERATE,
|
||||
title: '',
|
||||
dataIndex: 'operate',
|
||||
fixed: 'right',
|
||||
@@ -584,17 +679,97 @@ const ChannelsTable = () => {
|
||||
const [isBatchTesting, setIsBatchTesting] = useState(false);
|
||||
const [testQueue, setTestQueue] = useState([]);
|
||||
const [isProcessingQueue, setIsProcessingQueue] = useState(false);
|
||||
|
||||
// Form API 引用
|
||||
const [activeTypeKey, setActiveTypeKey] = useState('all');
|
||||
const [typeCounts, setTypeCounts] = useState({});
|
||||
const requestCounter = useRef(0);
|
||||
const [formApi, setFormApi] = useState(null);
|
||||
|
||||
// Form 初始值
|
||||
const formInitValues = {
|
||||
searchKeyword: '',
|
||||
searchGroup: '',
|
||||
searchModel: '',
|
||||
};
|
||||
|
||||
// Filter columns based on visibility settings
|
||||
const getVisibleColumns = () => {
|
||||
return allColumns.filter((column) => visibleColumns[column.key]);
|
||||
};
|
||||
|
||||
// Column selector modal
|
||||
const renderColumnSelector = () => {
|
||||
return (
|
||||
<Modal
|
||||
title={t('列设置')}
|
||||
visible={showColumnSelector}
|
||||
onCancel={() => setShowColumnSelector(false)}
|
||||
footer={
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
theme="light"
|
||||
onClick={() => initDefaultColumns()}
|
||||
className="!rounded-full"
|
||||
>
|
||||
{t('重置')}
|
||||
</Button>
|
||||
<Button
|
||||
theme="light"
|
||||
onClick={() => setShowColumnSelector(false)}
|
||||
className="!rounded-full"
|
||||
>
|
||||
{t('取消')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
onClick={() => setShowColumnSelector(false)}
|
||||
className="!rounded-full"
|
||||
>
|
||||
{t('确定')}
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<div style={{ marginBottom: 20 }}>
|
||||
<Checkbox
|
||||
checked={Object.values(visibleColumns).every((v) => v === true)}
|
||||
indeterminate={
|
||||
Object.values(visibleColumns).some((v) => v === true) &&
|
||||
!Object.values(visibleColumns).every((v) => v === true)
|
||||
}
|
||||
onChange={(e) => handleSelectAll(e.target.checked)}
|
||||
>
|
||||
{t('全选')}
|
||||
</Checkbox>
|
||||
</div>
|
||||
<div
|
||||
className="flex flex-wrap max-h-96 overflow-y-auto rounded-lg p-4"
|
||||
style={{ border: '1px solid var(--semi-color-border)' }}
|
||||
>
|
||||
{allColumns.map((column) => {
|
||||
// Skip columns without title
|
||||
if (!column.title) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
key={column.key}
|
||||
className="w-1/2 mb-4 pr-2"
|
||||
>
|
||||
<Checkbox
|
||||
checked={!!visibleColumns[column.key]}
|
||||
onChange={(e) =>
|
||||
handleColumnVisibilityChange(column.key, e.target.checked)
|
||||
}
|
||||
>
|
||||
{column.title}
|
||||
</Checkbox>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
const removeRecord = (record) => {
|
||||
let newDataSource = [...channels];
|
||||
if (record.id != null) {
|
||||
@@ -686,32 +861,28 @@ const ChannelsTable = () => {
|
||||
tagChannelDates.response_time = tagChannelDates.response_time / 2;
|
||||
}
|
||||
}
|
||||
// data.key = '' + data.id
|
||||
setChannels(channelDates);
|
||||
if (channelDates.length >= pageSize) {
|
||||
setChannelCount(channelDates.length + pageSize);
|
||||
} else {
|
||||
setChannelCount(channelDates.length);
|
||||
}
|
||||
};
|
||||
|
||||
const loadChannels = async (startIdx, pageSize, idSort, enableTagMode) => {
|
||||
const loadChannels = async (page, pageSize, idSort, enableTagMode, typeKey = activeTypeKey) => {
|
||||
const reqId = ++requestCounter.current; // 记录当前请求序号
|
||||
setLoading(true);
|
||||
const typeParam = (!enableTagMode && typeKey !== 'all') ? `&type=${typeKey}` : '';
|
||||
const res = await API.get(
|
||||
`/api/channel/?p=${startIdx}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}`,
|
||||
`/api/channel/?p=${page}&page_size=${pageSize}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}`,
|
||||
);
|
||||
if (res === undefined) {
|
||||
if (res === undefined || reqId !== requestCounter.current) {
|
||||
return;
|
||||
}
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (startIdx === 0) {
|
||||
setChannelFormat(data, enableTagMode);
|
||||
} else {
|
||||
let newChannels = [...channels];
|
||||
newChannels.splice(startIdx * pageSize, data.length, ...data);
|
||||
setChannelFormat(newChannels, enableTagMode);
|
||||
const { items, total, type_counts } = data;
|
||||
if (type_counts) {
|
||||
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
|
||||
setTypeCounts({ ...type_counts, all: sumAll });
|
||||
}
|
||||
setChannelFormat(items, enableTagMode);
|
||||
setChannelCount(total);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
@@ -724,7 +895,6 @@ const ChannelsTable = () => {
|
||||
channelToCopy.created_time = null;
|
||||
channelToCopy.balance = 0;
|
||||
channelToCopy.used_quota = 0;
|
||||
// 删除可能导致类型不匹配的字段
|
||||
delete channelToCopy.test_time;
|
||||
delete channelToCopy.response_time;
|
||||
if (!channelToCopy) {
|
||||
@@ -748,7 +918,7 @@ const ChannelsTable = () => {
|
||||
const refresh = async () => {
|
||||
const { searchKeyword, searchGroup, searchModel } = getFormValues();
|
||||
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
|
||||
await loadChannels(activePage - 1, pageSize, idSort, enableTagMode);
|
||||
await loadChannels(activePage, pageSize, idSort, enableTagMode);
|
||||
} else {
|
||||
await searchChannels(enableTagMode);
|
||||
}
|
||||
@@ -765,7 +935,7 @@ const ChannelsTable = () => {
|
||||
setPageSize(localPageSize);
|
||||
setEnableTagMode(localEnableTagMode);
|
||||
setEnableBatchDelete(localEnableBatchDelete);
|
||||
loadChannels(0, localPageSize, localIdSort, localEnableTagMode)
|
||||
loadChannels(1, localPageSize, localIdSort, localEnableTagMode)
|
||||
.then()
|
||||
.catch((reason) => {
|
||||
showError(reason);
|
||||
@@ -873,16 +1043,19 @@ const ChannelsTable = () => {
|
||||
try {
|
||||
if (searchKeyword === '' && searchGroup === '' && searchModel === '') {
|
||||
await loadChannels(activePage - 1, pageSize, idSort, enableTagMode);
|
||||
// setActivePage(1);
|
||||
return;
|
||||
}
|
||||
|
||||
const typeParam = (!enableTagMode && activeTypeKey !== 'all') ? `&type=${activeTypeKey}` : '';
|
||||
const res = await API.get(
|
||||
`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${idSort}&tag_mode=${enableTagMode}`,
|
||||
`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}&id_sort=${idSort}&tag_mode=${enableTagMode}${typeParam}`,
|
||||
);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
setChannelFormat(data, enableTagMode);
|
||||
const { items = [], type_counts = {} } = data;
|
||||
const sumAll = Object.values(type_counts).reduce((acc, v) => acc + v, 0);
|
||||
setTypeCounts({ ...type_counts, all: sumAll });
|
||||
setChannelFormat(items, enableTagMode);
|
||||
setActivePage(1);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -1012,24 +1185,105 @@ const ChannelsTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
let pageData = channels.slice(
|
||||
(activePage - 1) * pageSize,
|
||||
activePage * pageSize,
|
||||
);
|
||||
const channelTypeCounts = useMemo(() => {
|
||||
if (Object.keys(typeCounts).length > 0) return typeCounts;
|
||||
// fallback 本地计算
|
||||
const counts = { all: channels.length };
|
||||
channels.forEach((channel) => {
|
||||
const collect = (ch) => {
|
||||
const type = ch.type;
|
||||
counts[type] = (counts[type] || 0) + 1;
|
||||
};
|
||||
if (channel.children !== undefined) {
|
||||
channel.children.forEach(collect);
|
||||
} else {
|
||||
collect(channel);
|
||||
}
|
||||
});
|
||||
return counts;
|
||||
}, [typeCounts, channels]);
|
||||
|
||||
const availableTypeKeys = useMemo(() => {
|
||||
const keys = ['all'];
|
||||
Object.entries(channelTypeCounts).forEach(([k, v]) => {
|
||||
if (k !== 'all' && v > 0) keys.push(String(k));
|
||||
});
|
||||
return keys;
|
||||
}, [channelTypeCounts]);
|
||||
|
||||
const renderTypeTabs = () => {
|
||||
if (enableTagMode) return null;
|
||||
|
||||
return (
|
||||
<Tabs
|
||||
activeKey={activeTypeKey}
|
||||
type="card"
|
||||
collapsible
|
||||
onChange={(key) => {
|
||||
setActiveTypeKey(key);
|
||||
setActivePage(1);
|
||||
loadChannels(1, pageSize, idSort, enableTagMode, key);
|
||||
}}
|
||||
className="mb-4"
|
||||
>
|
||||
<TabPane
|
||||
itemKey="all"
|
||||
tab={
|
||||
<span className="flex items-center gap-2">
|
||||
{t('全部')}
|
||||
<Tag color={activeTypeKey === 'all' ? 'red' : 'grey'} size='small' shape='circle'>
|
||||
{channelTypeCounts['all'] || 0}
|
||||
</Tag>
|
||||
</span>
|
||||
}
|
||||
/>
|
||||
|
||||
{CHANNEL_OPTIONS.filter((opt) => availableTypeKeys.includes(String(opt.value))).map((option) => {
|
||||
const key = String(option.value);
|
||||
const count = channelTypeCounts[option.value] || 0;
|
||||
return (
|
||||
<TabPane
|
||||
key={key}
|
||||
itemKey={key}
|
||||
tab={
|
||||
<span className="flex items-center gap-2">
|
||||
{getChannelIcon(option.value)}
|
||||
{option.label}
|
||||
<Tag color={activeTypeKey === key ? 'red' : 'grey'} size='small' shape='circle'>
|
||||
{count}
|
||||
</Tag>
|
||||
</span>
|
||||
}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</Tabs>
|
||||
);
|
||||
};
|
||||
|
||||
let pageData = channels;
|
||||
if (activeTypeKey !== 'all') {
|
||||
const typeVal = parseInt(activeTypeKey);
|
||||
if (!isNaN(typeVal)) {
|
||||
pageData = pageData.filter((ch) => {
|
||||
if (ch.children !== undefined) {
|
||||
return ch.children.some((c) => c.type === typeVal);
|
||||
}
|
||||
return ch.type === typeVal;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const handlePageChange = (page) => {
|
||||
setActivePage(page);
|
||||
if (page === Math.ceil(channels.length / pageSize) + 1) {
|
||||
// In this case we have to load more data and then append them.
|
||||
loadChannels(page - 1, pageSize, idSort, enableTagMode).then((r) => { });
|
||||
}
|
||||
loadChannels(page, pageSize, idSort, enableTagMode).then(() => { });
|
||||
};
|
||||
|
||||
const handlePageSizeChange = async (size) => {
|
||||
localStorage.setItem('page-size', size + '');
|
||||
setPageSize(size);
|
||||
setActivePage(1);
|
||||
loadChannels(0, size, idSort, enableTagMode)
|
||||
loadChannels(1, size, idSort, enableTagMode)
|
||||
.then()
|
||||
.catch((reason) => {
|
||||
showError(reason);
|
||||
@@ -1039,8 +1293,6 @@ const ChannelsTable = () => {
|
||||
const fetchGroups = async () => {
|
||||
try {
|
||||
let res = await API.get(`/api/group/`);
|
||||
// add 'all' option
|
||||
// res.data.data.unshift('all');
|
||||
if (res === undefined) {
|
||||
return;
|
||||
}
|
||||
@@ -1212,6 +1464,7 @@ const ChannelsTable = () => {
|
||||
|
||||
const renderHeader = () => (
|
||||
<div className="flex flex-col w-full">
|
||||
{renderTypeTabs()}
|
||||
<div className="flex flex-col md:flex-row justify-between gap-4">
|
||||
<div className="flex flex-wrap md:flex-nowrap items-center gap-2 w-full md:w-auto order-2 md:order-1">
|
||||
<Button
|
||||
@@ -1335,7 +1588,7 @@ const ChannelsTable = () => {
|
||||
onChange={(v) => {
|
||||
localStorage.setItem('id-sort', v + '');
|
||||
setIdSort(v);
|
||||
loadChannels(0, pageSize, v, enableTagMode);
|
||||
loadChannels(activePage, pageSize, v, enableTagMode);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
@@ -1362,7 +1615,8 @@ const ChannelsTable = () => {
|
||||
onChange={(v) => {
|
||||
localStorage.setItem('enable-tag-mode', v + '');
|
||||
setEnableTagMode(v);
|
||||
loadChannels(0, pageSize, idSort, v);
|
||||
setActivePage(1);
|
||||
loadChannels(1, pageSize, idSort, v);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
@@ -1397,6 +1651,16 @@ const ChannelsTable = () => {
|
||||
>
|
||||
{t('刷新')}
|
||||
</Button>
|
||||
|
||||
<Button
|
||||
theme='light'
|
||||
type='tertiary'
|
||||
icon={<IconSetting />}
|
||||
onClick={() => setShowColumnSelector(true)}
|
||||
className="!rounded-full w-full md:w-auto"
|
||||
>
|
||||
{t('列设置')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col md:flex-row items-center gap-4 w-full md:w-auto order-1 md:order-2">
|
||||
@@ -1424,7 +1688,7 @@ const ChannelsTable = () => {
|
||||
<div className="w-full md:w-48">
|
||||
<Form.Input
|
||||
field="searchModel"
|
||||
prefix={<IconFilter />}
|
||||
prefix={<IconSearch />}
|
||||
placeholder={t('模型关键字')}
|
||||
className="!rounded-full"
|
||||
showClear
|
||||
@@ -1481,6 +1745,7 @@ const ChannelsTable = () => {
|
||||
|
||||
return (
|
||||
<>
|
||||
{renderColumnSelector()}
|
||||
<EditTagModal
|
||||
visible={showEditTag}
|
||||
tag={editingTag}
|
||||
@@ -1501,7 +1766,7 @@ const ChannelsTable = () => {
|
||||
bordered={false}
|
||||
>
|
||||
<Table
|
||||
columns={columns}
|
||||
columns={getVisibleColumns()}
|
||||
dataSource={pageData}
|
||||
scroll={{ x: 'max-content' }}
|
||||
pagination={{
|
||||
@@ -1513,7 +1778,7 @@ const ChannelsTable = () => {
|
||||
formatPageText: (page) => t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
|
||||
start: page.currentStart,
|
||||
end: page.currentEnd,
|
||||
total: channels.length,
|
||||
total: channelCount,
|
||||
}),
|
||||
onPageSizeChange: (size) => {
|
||||
handlePageSizeChange(size);
|
||||
@@ -1632,7 +1897,6 @@ const ChannelsTable = () => {
|
||||
</div>
|
||||
}
|
||||
maskClosable={!isBatchTesting}
|
||||
centered={true}
|
||||
className="!rounded-lg"
|
||||
size="large"
|
||||
>
|
||||
@@ -1733,7 +1997,6 @@ const ChannelsTable = () => {
|
||||
key: model
|
||||
}))}
|
||||
pagination={false}
|
||||
size="middle"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -20,7 +20,7 @@ import {
|
||||
renderQuota,
|
||||
stringToColor,
|
||||
getLogOther,
|
||||
renderModelTag
|
||||
renderModelTag,
|
||||
} from '../../helpers';
|
||||
|
||||
import {
|
||||
@@ -39,15 +39,16 @@ import {
|
||||
Card,
|
||||
Typography,
|
||||
Divider,
|
||||
Form
|
||||
Form,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
IllustrationNoResultDark
|
||||
IllustrationNoResultDark,
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { ITEMS_PER_PAGE } from '../../constants';
|
||||
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
|
||||
import { IconSetting, IconSearch, IconForward } from '@douyinfe/semi-icons';
|
||||
import { IconSetting, IconSearch, IconHelpCircle } from '@douyinfe/semi-icons';
|
||||
import { Route } from 'lucide-react';
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
@@ -191,7 +192,7 @@ const LogsTable = () => {
|
||||
if (!modelMapped) {
|
||||
return renderModelTag(record.model_name, {
|
||||
onClick: (event) => {
|
||||
copyText(event, record.model_name).then((r) => { });
|
||||
copyText(event, record.model_name).then((r) => {});
|
||||
},
|
||||
});
|
||||
} else {
|
||||
@@ -208,7 +209,7 @@ const LogsTable = () => {
|
||||
</Text>
|
||||
{renderModelTag(record.model_name, {
|
||||
onClick: (event) => {
|
||||
copyText(event, record.model_name).then((r) => { });
|
||||
copyText(event, record.model_name).then((r) => {});
|
||||
},
|
||||
})}
|
||||
</div>
|
||||
@@ -219,7 +220,7 @@ const LogsTable = () => {
|
||||
{renderModelTag(other.upstream_model_name, {
|
||||
onClick: (event) => {
|
||||
copyText(event, other.upstream_model_name).then(
|
||||
(r) => { },
|
||||
(r) => {},
|
||||
);
|
||||
},
|
||||
})}
|
||||
@@ -230,8 +231,13 @@ const LogsTable = () => {
|
||||
>
|
||||
{renderModelTag(record.model_name, {
|
||||
onClick: (event) => {
|
||||
copyText(event, record.model_name).then((r) => { });
|
||||
copyText(event, record.model_name).then((r) => {});
|
||||
},
|
||||
suffixIcon: (
|
||||
<Route
|
||||
style={{ width: '0.9em', height: '0.9em', opacity: 0.75 }}
|
||||
/>
|
||||
),
|
||||
})}
|
||||
</Popover>
|
||||
</Space>
|
||||
@@ -254,6 +260,7 @@ const LogsTable = () => {
|
||||
COMPLETION: 'completion',
|
||||
COST: 'cost',
|
||||
RETRY: 'retry',
|
||||
IP: 'ip',
|
||||
DETAILS: 'details',
|
||||
};
|
||||
|
||||
@@ -295,6 +302,7 @@ const LogsTable = () => {
|
||||
[COLUMN_KEYS.COMPLETION]: true,
|
||||
[COLUMN_KEYS.COST]: true,
|
||||
[COLUMN_KEYS.RETRY]: isAdminUser,
|
||||
[COLUMN_KEYS.IP]: true,
|
||||
[COLUMN_KEYS.DETAILS]: true,
|
||||
};
|
||||
};
|
||||
@@ -479,6 +487,9 @@ const LogsTable = () => {
|
||||
title: t('用时/首字'),
|
||||
dataIndex: 'use_time',
|
||||
render: (text, record, index) => {
|
||||
if (!(record.type === 2 || record.type === 5)) {
|
||||
return <></>;
|
||||
}
|
||||
if (record.is_stream) {
|
||||
let other = getLogOther(record.other);
|
||||
return (
|
||||
@@ -539,12 +550,45 @@ const LogsTable = () => {
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.IP,
|
||||
title: (
|
||||
<div className="flex items-center gap-1">
|
||||
{t('IP')}
|
||||
<Tooltip content={t('只有当用户设置开启IP记录时,才会进行请求和错误类型日志的IP记录')}>
|
||||
<IconHelpCircle className="text-gray-400 cursor-help" />
|
||||
</Tooltip>
|
||||
</div>
|
||||
),
|
||||
dataIndex: 'ip',
|
||||
render: (text, record, index) => {
|
||||
return (record.type === 2 || record.type === 5) && text ? (
|
||||
<Tooltip content={text}>
|
||||
<Tag
|
||||
color='orange'
|
||||
size='large'
|
||||
shape='circle'
|
||||
onClick={(event) => {
|
||||
copyText(event, text);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Tag>
|
||||
</Tooltip>
|
||||
) : (
|
||||
<></>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
key: COLUMN_KEYS.RETRY,
|
||||
title: t('重试'),
|
||||
dataIndex: 'retry',
|
||||
className: isAdmin() ? 'tableShow' : 'tableHiddle',
|
||||
render: (text, record, index) => {
|
||||
if (!(record.type === 2 || record.type === 5)) {
|
||||
return <></>;
|
||||
}
|
||||
let content = t('渠道') + `:${record.channel}`;
|
||||
if (record.other !== '') {
|
||||
let other = JSON.parse(record.other);
|
||||
@@ -592,21 +636,23 @@ const LogsTable = () => {
|
||||
}
|
||||
let content = other?.claude
|
||||
? renderClaudeModelPriceSimple(
|
||||
other.model_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other.cache_tokens || 0,
|
||||
other.cache_ratio || 1.0,
|
||||
other.cache_creation_tokens || 0,
|
||||
other.cache_creation_ratio || 1.0,
|
||||
)
|
||||
other.model_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
other.cache_tokens || 0,
|
||||
other.cache_ratio || 1.0,
|
||||
other.cache_creation_tokens || 0,
|
||||
other.cache_creation_ratio || 1.0,
|
||||
)
|
||||
: renderModelPriceSimple(
|
||||
other.model_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other.cache_tokens || 0,
|
||||
other.cache_ratio || 1.0,
|
||||
);
|
||||
other.model_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
other.cache_tokens || 0,
|
||||
other.cache_ratio || 1.0,
|
||||
);
|
||||
return (
|
||||
<Paragraph
|
||||
ellipsis={{
|
||||
@@ -736,7 +782,7 @@ const LogsTable = () => {
|
||||
group: '',
|
||||
dateRange: [
|
||||
timestamp2string(getTodayStartTimestamp()),
|
||||
timestamp2string(now.getTime() / 1000 + 3600)
|
||||
timestamp2string(now.getTime() / 1000 + 3600),
|
||||
],
|
||||
logType: '0',
|
||||
};
|
||||
@@ -757,7 +803,11 @@ const LogsTable = () => {
|
||||
let start_timestamp = timestamp2string(getTodayStartTimestamp());
|
||||
let end_timestamp = timestamp2string(now.getTime() / 1000 + 3600);
|
||||
|
||||
if (formValues.dateRange && Array.isArray(formValues.dateRange) && formValues.dateRange.length === 2) {
|
||||
if (
|
||||
formValues.dateRange &&
|
||||
Array.isArray(formValues.dateRange) &&
|
||||
formValues.dateRange.length === 2
|
||||
) {
|
||||
start_timestamp = formValues.dateRange[0];
|
||||
end_timestamp = formValues.dateRange[1];
|
||||
}
|
||||
@@ -935,27 +985,27 @@ const LogsTable = () => {
|
||||
key: t('日志详情'),
|
||||
value: other?.claude
|
||||
? renderClaudeLogContent(
|
||||
other?.model_ratio,
|
||||
other.completion_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other.cache_ratio || 1.0,
|
||||
other.cache_creation_ratio || 1.0,
|
||||
)
|
||||
other?.model_ratio,
|
||||
other.completion_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
other.cache_ratio || 1.0,
|
||||
other.cache_creation_ratio || 1.0,
|
||||
)
|
||||
: renderLogContent(
|
||||
other?.model_ratio,
|
||||
other.completion_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
false,
|
||||
1.0,
|
||||
undefined,
|
||||
other.web_search || false,
|
||||
other.web_search_call_count || 0,
|
||||
other.file_search || false,
|
||||
other.file_search_call_count || 0,
|
||||
),
|
||||
other?.model_ratio,
|
||||
other.completion_ratio,
|
||||
other.model_price,
|
||||
other.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
false,
|
||||
1.0,
|
||||
other.web_search || false,
|
||||
other.web_search_call_count || 0,
|
||||
other.file_search || false,
|
||||
other.file_search_call_count || 0,
|
||||
),
|
||||
});
|
||||
}
|
||||
if (logs[i].type === 2) {
|
||||
@@ -986,6 +1036,7 @@ const LogsTable = () => {
|
||||
other?.audio_ratio,
|
||||
other?.audio_completion_ratio,
|
||||
other?.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
other?.cache_tokens || 0,
|
||||
other?.cache_ratio || 1.0,
|
||||
);
|
||||
@@ -997,6 +1048,7 @@ const LogsTable = () => {
|
||||
other.model_price,
|
||||
other.completion_ratio,
|
||||
other.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
other.cache_tokens || 0,
|
||||
other.cache_ratio || 1.0,
|
||||
other.cache_creation_tokens || 0,
|
||||
@@ -1010,6 +1062,7 @@ const LogsTable = () => {
|
||||
other?.model_price,
|
||||
other?.completion_ratio,
|
||||
other?.group_ratio,
|
||||
other?.user_group_ratio,
|
||||
other?.cache_tokens || 0,
|
||||
other?.cache_ratio || 1.0,
|
||||
other?.image || false,
|
||||
@@ -1060,7 +1113,12 @@ const LogsTable = () => {
|
||||
} = getFormValues();
|
||||
|
||||
// 使用传入的 logType 或者表单中的 logType 或者状态中的 logType
|
||||
const currentLogType = customLogType !== null ? customLogType : formLogType !== undefined ? formLogType : logType;
|
||||
const currentLogType =
|
||||
customLogType !== null
|
||||
? customLogType
|
||||
: formLogType !== undefined
|
||||
? formLogType
|
||||
: logType;
|
||||
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
@@ -1087,7 +1145,7 @@ const LogsTable = () => {
|
||||
|
||||
const handlePageChange = (page) => {
|
||||
setActivePage(page);
|
||||
loadLogs(page, pageSize).then((r) => { }); // 不传入logType,让其从表单获取最新值
|
||||
loadLogs(page, pageSize).then((r) => {}); // 不传入logType,让其从表单获取最新值
|
||||
};
|
||||
|
||||
const handlePageSizeChange = async (size) => {
|
||||
@@ -1202,9 +1260,9 @@ const LogsTable = () => {
|
||||
getFormApi={(api) => setFormApi(api)}
|
||||
onSubmit={refresh}
|
||||
allowEmpty={true}
|
||||
autoComplete="off"
|
||||
layout="vertical"
|
||||
trigger="change"
|
||||
autoComplete='off'
|
||||
layout='vertical'
|
||||
trigger='change'
|
||||
stopValidateWithError={false}
|
||||
>
|
||||
<div className='flex flex-col gap-4'>
|
||||
@@ -1288,12 +1346,24 @@ const LogsTable = () => {
|
||||
}, 0);
|
||||
}}
|
||||
>
|
||||
<Form.Select.Option value='0'>{t('全部')}</Form.Select.Option>
|
||||
<Form.Select.Option value='1'>{t('充值')}</Form.Select.Option>
|
||||
<Form.Select.Option value='2'>{t('消费')}</Form.Select.Option>
|
||||
<Form.Select.Option value='3'>{t('管理')}</Form.Select.Option>
|
||||
<Form.Select.Option value='4'>{t('系统')}</Form.Select.Option>
|
||||
<Form.Select.Option value='5'>{t('错误')}</Form.Select.Option>
|
||||
<Form.Select.Option value='0'>
|
||||
{t('全部')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='1'>
|
||||
{t('充值')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='2'>
|
||||
{t('消费')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='3'>
|
||||
{t('管理')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='4'>
|
||||
{t('系统')}
|
||||
</Form.Select.Option>
|
||||
<Form.Select.Option value='5'>
|
||||
{t('错误')}
|
||||
</Form.Select.Option>
|
||||
</Form.Select>
|
||||
</div>
|
||||
|
||||
@@ -1345,7 +1415,8 @@ const LogsTable = () => {
|
||||
{...(hasExpandableRows() && {
|
||||
expandedRowRender: expandRowRender,
|
||||
expandRowByClick: true,
|
||||
rowExpandable: (record) => expandData[record.key] && expandData[record.key].length > 0
|
||||
rowExpandable: (record) =>
|
||||
expandData[record.key] && expandData[record.key].length > 0,
|
||||
})}
|
||||
dataSource={logs}
|
||||
rowKey='key'
|
||||
@@ -1355,8 +1426,12 @@ const LogsTable = () => {
|
||||
size='middle'
|
||||
empty={
|
||||
<Empty
|
||||
image={<IllustrationNoResult style={{ width: 150, height: 150 }} />}
|
||||
darkModeImage={<IllustrationNoResultDark style={{ width: 150, height: 150 }} />}
|
||||
image={
|
||||
<IllustrationNoResult style={{ width: 150, height: 150 }} />
|
||||
}
|
||||
darkModeImage={
|
||||
<IllustrationNoResultDark style={{ width: 150, height: 150 }} />
|
||||
}
|
||||
description={t('搜索无结果')}
|
||||
style={{ padding: 30 }}
|
||||
/>
|
||||
|
||||
@@ -601,7 +601,7 @@ const LogsTable = () => {
|
||||
const [logs, setLogs] = useState([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [activePage, setActivePage] = useState(1);
|
||||
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||
const [logCount, setLogCount] = useState(0);
|
||||
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
|
||||
const [isModalOpenurl, setIsModalOpenurl] = useState(false);
|
||||
const [showBanner, setShowBanner] = useState(false);
|
||||
@@ -649,69 +649,53 @@ const LogsTable = () => {
|
||||
};
|
||||
};
|
||||
|
||||
const setLogsFormat = (logs) => {
|
||||
for (let i = 0; i < logs.length; i++) {
|
||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||
logs[i].key = '' + logs[i].id;
|
||||
}
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + pageSize);
|
||||
// console.log(logCount);
|
||||
const enrichLogs = (items) => {
|
||||
return items.map((log) => ({
|
||||
...log,
|
||||
timestamp2string: timestamp2string(log.created_at),
|
||||
key: '' + log.id,
|
||||
}));
|
||||
};
|
||||
|
||||
const loadLogs = async (startIdx, pageSize = ITEMS_PER_PAGE) => {
|
||||
setLoading(true);
|
||||
const syncPageData = (payload) => {
|
||||
const items = enrichLogs(payload.items || []);
|
||||
setLogs(items);
|
||||
setLogCount(payload.total || 0);
|
||||
setActivePage(payload.page || 1);
|
||||
setPageSize(payload.page_size || pageSize);
|
||||
};
|
||||
|
||||
let url = '';
|
||||
const loadLogs = async (page = 1, size = pageSize) => {
|
||||
setLoading(true);
|
||||
const { channel_id, mj_id, start_timestamp, end_timestamp } = getFormValues();
|
||||
let localStartTimestamp = Date.parse(start_timestamp);
|
||||
let localEndTimestamp = Date.parse(end_timestamp);
|
||||
if (isAdminUser) {
|
||||
url = `/api/mj/?p=${startIdx}&page_size=${pageSize}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
} else {
|
||||
url = `/api/mj/self/?p=${startIdx}&page_size=${pageSize}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
const url = isAdminUser
|
||||
? `/api/mj/?p=${page}&page_size=${size}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`
|
||||
: `/api/mj/self/?p=${page}&page_size=${size}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
const res = await API.get(url);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
if (startIdx === 0) {
|
||||
setLogsFormat(data);
|
||||
} else {
|
||||
let newLogs = [...logs];
|
||||
newLogs.splice(startIdx * pageSize, data.length, ...data);
|
||||
setLogsFormat(newLogs);
|
||||
}
|
||||
syncPageData(data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const pageData = logs.slice(
|
||||
(activePage - 1) * pageSize,
|
||||
activePage * pageSize,
|
||||
);
|
||||
const pageData = logs;
|
||||
|
||||
const handlePageChange = (page) => {
|
||||
setActivePage(page);
|
||||
if (page === Math.ceil(logs.length / pageSize) + 1) {
|
||||
// In this case we have to load more data and then append them.
|
||||
loadLogs(page - 1, pageSize).then((r) => { });
|
||||
}
|
||||
loadLogs(page, pageSize).then();
|
||||
};
|
||||
|
||||
const handlePageSizeChange = async (size) => {
|
||||
localStorage.setItem('mj-page-size', size + '');
|
||||
setPageSize(size);
|
||||
setActivePage(1);
|
||||
await loadLogs(0, size);
|
||||
await loadLogs(1, size);
|
||||
};
|
||||
|
||||
const refresh = async () => {
|
||||
// setLoading(true);
|
||||
setActivePage(1);
|
||||
await loadLogs(0, pageSize);
|
||||
await loadLogs(1, pageSize);
|
||||
};
|
||||
|
||||
const copyText = async (text) => {
|
||||
@@ -726,7 +710,7 @@ const LogsTable = () => {
|
||||
useEffect(() => {
|
||||
const localPageSize = parseInt(localStorage.getItem('mj-page-size')) || ITEMS_PER_PAGE;
|
||||
setPageSize(localPageSize);
|
||||
loadLogs(0, localPageSize).then();
|
||||
loadLogs(1, localPageSize).then();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -936,7 +920,7 @@ const LogsTable = () => {
|
||||
>
|
||||
<Table
|
||||
columns={getVisibleColumns()}
|
||||
dataSource={pageData}
|
||||
dataSource={logs}
|
||||
rowKey='key'
|
||||
loading={loading}
|
||||
scroll={{ x: 'max-content' }}
|
||||
@@ -962,9 +946,7 @@ const LogsTable = () => {
|
||||
total: logCount,
|
||||
pageSizeOptions: [10, 20, 50, 100],
|
||||
showSizeChanger: true,
|
||||
onPageSizeChange: (size) => {
|
||||
handlePageSizeChange(size);
|
||||
},
|
||||
onPageSizeChange: handlePageSizeChange,
|
||||
onPageChange: handlePageChange,
|
||||
}}
|
||||
/>
|
||||
|
||||
@@ -13,7 +13,8 @@ import {
|
||||
XCircle,
|
||||
Minus,
|
||||
HelpCircle,
|
||||
Coins
|
||||
Coins,
|
||||
Ticket
|
||||
} from 'lucide-react';
|
||||
|
||||
import { ITEMS_PER_PAGE } from '../../constants';
|
||||
@@ -58,7 +59,16 @@ function renderTimestamp(timestamp) {
|
||||
const RedemptionsTable = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const renderStatus = (status) => {
|
||||
const isExpired = (rec) => {
|
||||
return rec.status === 1 && rec.expired_time !== 0 && rec.expired_time < Math.floor(Date.now() / 1000);
|
||||
};
|
||||
|
||||
const renderStatus = (status, record) => {
|
||||
if (isExpired(record)) {
|
||||
return (
|
||||
<Tag color='orange' size='large' shape='circle' prefixIcon={<Minus size={14} />}>{t('已过期')}</Tag>
|
||||
);
|
||||
}
|
||||
switch (status) {
|
||||
case 1:
|
||||
return (
|
||||
@@ -101,7 +111,7 @@ const RedemptionsTable = () => {
|
||||
dataIndex: 'status',
|
||||
key: 'status',
|
||||
render: (text, record, index) => {
|
||||
return <div>{renderStatus(text)}</div>;
|
||||
return <div>{renderStatus(text, record)}</div>;
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -124,6 +134,13 @@ const RedemptionsTable = () => {
|
||||
return <div>{renderTimestamp(text)}</div>;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('过期时间'),
|
||||
dataIndex: 'expired_time',
|
||||
render: (text) => {
|
||||
return <div>{text === 0 ? t('永不过期') : renderTimestamp(text)}</div>;
|
||||
},
|
||||
},
|
||||
{
|
||||
title: t('兑换人ID'),
|
||||
dataIndex: 'used_user_id',
|
||||
@@ -157,8 +174,7 @@ const RedemptionsTable = () => {
|
||||
}
|
||||
];
|
||||
|
||||
// 动态添加启用/禁用按钮
|
||||
if (record.status === 1) {
|
||||
if (record.status === 1 && !isExpired(record)) {
|
||||
moreMenuItems.push({
|
||||
node: 'item',
|
||||
name: t('禁用'),
|
||||
@@ -168,7 +184,7 @@ const RedemptionsTable = () => {
|
||||
manageRedemption(record.id, 'disable', record);
|
||||
},
|
||||
});
|
||||
} else {
|
||||
} else if (!isExpired(record)) {
|
||||
moreMenuItems.push({
|
||||
node: 'item',
|
||||
name: t('启用'),
|
||||
@@ -435,7 +451,7 @@ const RedemptionsTable = () => {
|
||||
};
|
||||
|
||||
const handleRow = (record, index) => {
|
||||
if (record.status !== 1) {
|
||||
if (record.status !== 1 || isExpired(record)) {
|
||||
return {
|
||||
style: {
|
||||
background: 'var(--semi-color-disabled-border)',
|
||||
@@ -450,7 +466,7 @@ const RedemptionsTable = () => {
|
||||
<div className="flex flex-col w-full">
|
||||
<div className="mb-2">
|
||||
<div className="flex items-center text-orange-500">
|
||||
<IconEyeOpened className="mr-2" />
|
||||
<Ticket size={16} className="mr-2" />
|
||||
<Text>{t('兑换码可以批量生成和分发,适合用于推广活动或批量充值。')}</Text>
|
||||
</div>
|
||||
</div>
|
||||
@@ -458,39 +474,66 @@ const RedemptionsTable = () => {
|
||||
<Divider margin="12px" />
|
||||
|
||||
<div className="flex flex-col md:flex-row justify-between items-center gap-4 w-full">
|
||||
<div className="flex gap-2 w-full md:w-auto order-2 md:order-1">
|
||||
<div className="flex flex-col sm:flex-row gap-2 w-full md:w-auto order-2 md:order-1">
|
||||
<div className="flex gap-2 w-full sm:w-auto">
|
||||
<Button
|
||||
theme='light'
|
||||
type='primary'
|
||||
icon={<IconPlus />}
|
||||
className="!rounded-full w-full sm:w-auto"
|
||||
onClick={() => {
|
||||
setEditingRedemption({
|
||||
id: undefined,
|
||||
});
|
||||
setShowEdit(true);
|
||||
}}
|
||||
>
|
||||
{t('添加兑换码')}
|
||||
</Button>
|
||||
<Button
|
||||
type='warning'
|
||||
icon={<IconCopy />}
|
||||
className="!rounded-full w-full sm:w-auto"
|
||||
onClick={async () => {
|
||||
if (selectedKeys.length === 0) {
|
||||
showError(t('请至少选择一个兑换码!'));
|
||||
return;
|
||||
}
|
||||
let keys = '';
|
||||
for (let i = 0; i < selectedKeys.length; i++) {
|
||||
keys +=
|
||||
selectedKeys[i].name + ' ' + selectedKeys[i].key + '\n';
|
||||
}
|
||||
await copyText(keys);
|
||||
}}
|
||||
>
|
||||
{t('复制所选兑换码到剪贴板')}
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
theme='light'
|
||||
type='primary'
|
||||
icon={<IconPlus />}
|
||||
className="!rounded-full w-full md:w-auto"
|
||||
type='danger'
|
||||
icon={<IconDelete />}
|
||||
className="!rounded-full w-full sm:w-auto"
|
||||
onClick={() => {
|
||||
setEditingRedemption({
|
||||
id: undefined,
|
||||
Modal.confirm({
|
||||
title: t('确定清除所有失效兑换码?'),
|
||||
content: t('将删除已使用、已禁用及过期的兑换码,此操作不可撤销。'),
|
||||
onOk: async () => {
|
||||
setLoading(true);
|
||||
const res = await API.delete('/api/redemption/invalid');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(t('已删除 {{count}} 条失效兑换码', { count: data }));
|
||||
await refresh();
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
},
|
||||
});
|
||||
setShowEdit(true);
|
||||
}}
|
||||
>
|
||||
{t('添加兑换码')}
|
||||
</Button>
|
||||
<Button
|
||||
type='warning'
|
||||
icon={<IconCopy />}
|
||||
className="!rounded-full w-full md:w-auto"
|
||||
onClick={async () => {
|
||||
if (selectedKeys.length === 0) {
|
||||
showError(t('请至少选择一个兑换码!'));
|
||||
return;
|
||||
}
|
||||
let keys = '';
|
||||
for (let i = 0; i < selectedKeys.length; i++) {
|
||||
keys +=
|
||||
selectedKeys[i].name + ' ' + selectedKeys[i].key + '\n';
|
||||
}
|
||||
await copyText(keys);
|
||||
}}
|
||||
>
|
||||
{t('复制所选兑换码到剪贴板')}
|
||||
{t('清除失效兑换码')}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user