mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-31 13:18:19 +00:00
Compare commits
54 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
954ed893b9 | ||
|
|
95e84f2bb1 | ||
|
|
f3f8cdc4a3 | ||
|
|
1244963e81 | ||
|
|
8f36a995ef | ||
|
|
daf0be4915 | ||
|
|
33b93e1e9a | ||
|
|
f3124e7252 | ||
|
|
ebedf2cb7e | ||
|
|
d68a50125b | ||
|
|
18e9d37057 | ||
|
|
e2d994d73a | ||
|
|
29dbdf01f0 | ||
|
|
64862cd634 | ||
|
|
a5e0f86c65 | ||
|
|
2a995a5da2 | ||
|
|
bba6174745 | ||
|
|
b91a269ddb | ||
|
|
1c2bba8979 | ||
|
|
ce05e7dd86 | ||
|
|
bf8794d257 | ||
|
|
c09df83f34 | ||
|
|
79b1201f47 | ||
|
|
b3fc335908 | ||
|
|
290fed3841 | ||
|
|
aa6d623981 | ||
|
|
afbd585048 | ||
|
|
5c747dfee2 | ||
|
|
89dd0e05a0 | ||
|
|
7e4fe14871 | ||
|
|
85411bc167 | ||
|
|
b86bc169a5 | ||
|
|
bdd611fd33 | ||
|
|
1a8a24698f | ||
|
|
c09f3b9282 | ||
|
|
e8bcf60f0a | ||
|
|
14592f9758 | ||
|
|
4036355fae | ||
|
|
fa2efb7357 | ||
|
|
7a8344c40a | ||
|
|
6ad2544415 | ||
|
|
7cd1261a81 | ||
|
|
45ed973f1c | ||
|
|
07fe5a02af | ||
|
|
7a4969c238 | ||
|
|
ad6842da7f | ||
|
|
3c52a0991f | ||
|
|
fd4ef086dc | ||
|
|
7c4719b6ee | ||
|
|
b9d040cf52 | ||
|
|
6e8ff8c057 | ||
|
|
676dc95793 | ||
|
|
1c06bddafe | ||
|
|
3475643257 |
6
LICENSE
6
LICENSE
@@ -1,10 +1,6 @@
|
||||
MIT License
|
||||
|
||||
New API
|
||||
Copyright (c) 2023 CalciumIon
|
||||
|
||||
Based on One API
|
||||
Copyright (c) 2023 JustSong
|
||||
Copyright (c) 2024 Calcium-Ion
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
14
README.md
14
README.md
@@ -32,6 +32,11 @@
|
||||
7. 支持 gpt-4-1106-vision-preview,dall-e-3,tts-1
|
||||
8. 支持第三方模型 **gps** (gpt-4-gizmo-*),在渠道中添加自定义模型gpt-4-gizmo-*即可
|
||||
9. 兼容原版One API的数据库,可直接使用原版数据库(one-api.db)
|
||||
10. 支持模型按次数收费,可在 系统设置-运营设置 中设置
|
||||
11. 支持gemini-pro,gemini-pro-vision模型
|
||||
12. 支持渠道**加权随机**
|
||||
13. 数据看板
|
||||
14. 可设置令牌能调用的模型
|
||||
|
||||
## 部署
|
||||
### 基于 Docker 进行部署
|
||||
@@ -44,9 +49,11 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||
```
|
||||
|
||||
## 交流群
|
||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="500">
|
||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
|
||||
|
||||
## 界面截图
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
@@ -54,8 +61,11 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
|
||||

|
||||

|
||||
夜间模式
|
||||

|
||||
|
||||

|
||||

|
||||
|
||||
## Star History
|
||||
|
||||
|
||||
[](https://star-history.com/#Calcium-Ion/new-api&Date)
|
||||
|
||||
@@ -24,6 +24,9 @@ var ChatLink = ""
|
||||
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||
var DisplayInCurrencyEnabled = true
|
||||
var DisplayTokenStatEnabled = true
|
||||
var DrawingEnabled = true
|
||||
var DataExportEnabled = true
|
||||
var DataExportInterval = 5 // unit: minute
|
||||
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, error) {
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
base64String = base64String[idx+1:]
|
||||
@@ -22,20 +22,51 @@ func DecodeBase64ImageData(base64String string) (image.Config, error) {
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
fmt.Println("Error: Failed to decode base64 string")
|
||||
return image.Config{}, err
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
// 创建一个bytes.Buffer用于存储解码后的数据
|
||||
reader := bytes.NewReader(decodedData)
|
||||
config, err := getImageConfig(reader)
|
||||
return config, err
|
||||
config, format, err := getImageConfig(reader)
|
||||
return config, format, err
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, error) {
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
resp, err := http.Head(url)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
_, err = buffer.ReadFrom(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mimeType = resp.Header.Get("Content-Type")
|
||||
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
response, err := http.Get(imageUrl)
|
||||
if err != nil {
|
||||
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, err
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
// 限制读取的字节数,防止下载整个图片
|
||||
@@ -45,14 +76,14 @@ func DecodeUrlImageData(imageUrl string) (image.Config, error) {
|
||||
// log.Fatal(err)
|
||||
//}
|
||||
//log.Printf("%x", data)
|
||||
config, err := getImageConfig(limitReader)
|
||||
config, format, err := getImageConfig(limitReader)
|
||||
response.Body.Close()
|
||||
return config, err
|
||||
return config, format, err
|
||||
}
|
||||
|
||||
func getImageConfig(reader io.Reader) (image.Config, error) {
|
||||
func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
||||
// 读取图片的头部信息来获取图片尺寸
|
||||
config, _, err := image.DecodeConfig(reader)
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
SysLog(err.Error())
|
||||
@@ -61,9 +92,10 @@ func getImageConfig(reader io.Reader) (image.Config, error) {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
SysLog(err.Error())
|
||||
}
|
||||
format = "webp"
|
||||
}
|
||||
if err != nil {
|
||||
return image.Config{}, err
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
return config, nil
|
||||
return config, format, nil
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ var (
|
||||
)
|
||||
|
||||
func printHelp() {
|
||||
fmt.Println("One API " + Version + " - All in one API service for OpenAI API.")
|
||||
fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
|
||||
fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
|
||||
fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
// 1 === $0.002 / 1K tokens
|
||||
// 1 === ¥0.014 / 1k tokens
|
||||
var ModelRatio = map[string]float64{
|
||||
"midjourney": 50,
|
||||
//"midjourney": 50,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"gpt-4": 15,
|
||||
"gpt-4-0314": 15,
|
||||
@@ -32,6 +32,8 @@ var ModelRatio = map[string]float64{
|
||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
||||
"davinci-002": 1, // $0.002 / 1K tokens
|
||||
"text-ada-001": 0.2,
|
||||
"text-babbage-001": 0.25,
|
||||
"text-curie-001": 1,
|
||||
@@ -62,6 +64,7 @@ var ModelRatio = map[string]float64{
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
@@ -77,6 +80,41 @@ var ModelRatio = map[string]float64{
|
||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||
}
|
||||
|
||||
var ModelPrice = map[string]float64{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
}
|
||||
|
||||
func ModelPrice2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(ModelPrice)
|
||||
if err != nil {
|
||||
SysError("error marshalling model price: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateModelPriceByJSONString(jsonStr string) error {
|
||||
ModelPrice = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &ModelPrice)
|
||||
}
|
||||
|
||||
func GetModelPrice(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
price, ok := ModelPrice[name]
|
||||
if !ok {
|
||||
SysError("model price not found: " + name)
|
||||
return -1
|
||||
}
|
||||
return price
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(ModelRatio)
|
||||
if err != nil {
|
||||
|
||||
@@ -168,6 +168,11 @@ func GetRandomString(length int) string {
|
||||
return string(key)
|
||||
}
|
||||
|
||||
func GetRandomInt(max int) int {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
return rand.Intn(max)
|
||||
}
|
||||
|
||||
func GetTimestamp() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
@@ -16,8 +16,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateMidjourneyTask() {
|
||||
/*func UpdateMidjourneyTask() {
|
||||
//revocer
|
||||
//imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
imageModel := "midjourney"
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -28,27 +30,28 @@ func UpdateMidjourneyTask() {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) != 0 {
|
||||
log.Printf("检测到未完成的任务数有: %v", len(tasks))
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
for _, task := range tasks {
|
||||
log.Printf("未完成的任务信息: %v", task)
|
||||
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
|
||||
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask: %v", err)
|
||||
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
|
||||
task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
|
||||
task.Status = "FAILURE"
|
||||
task.Progress = "100%"
|
||||
err := task.Update()
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
|
||||
log.Printf("requestUrl: %s", requestUrl)
|
||||
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
|
||||
|
||||
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
||||
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -111,7 +114,7 @@ func UpdateMidjourneyTask() {
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
log.Println(task.MjId + " 构建失败," + task.FailReason)
|
||||
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
if err != nil {
|
||||
@@ -126,8 +129,8 @@ func UpdateMidjourneyTask() {
|
||||
if err != nil {
|
||||
log.Println("fail to increase user quota")
|
||||
}
|
||||
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, 1, logContent)
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -142,6 +145,180 @@ func UpdateMidjourneyTask() {
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
//revocer
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
||||
}
|
||||
}()
|
||||
imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Midjourney)
|
||||
for _, task := range tasks {
|
||||
if task.MjId == "" {
|
||||
continue
|
||||
}
|
||||
taskM[task.MjId] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
continue
|
||||
}
|
||||
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"ids": taskIds,
|
||||
})
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 5
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []Midjourney
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
req.Body.Close()
|
||||
cancel()
|
||||
|
||||
for _, responseItem := range responseItems {
|
||||
task := taskM[responseItem.MjId]
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
task.State = responseItem.State
|
||||
task.SubmitTime = responseItem.SubmitTime
|
||||
task.StartTime = responseItem.StartTime
|
||||
task.FinishTime = responseItem.FinishTime
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio("default")
|
||||
ratio := modelRatio * groupRatio
|
||||
quota := int(ratio * 1 * 1000)
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
|
||||
if oldTask.Code != 1 {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != newTask.Progress {
|
||||
return true
|
||||
}
|
||||
if oldTask.PromptEn != newTask.PromptEn {
|
||||
return true
|
||||
}
|
||||
if oldTask.State != newTask.State {
|
||||
return true
|
||||
}
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.ImageUrl != newTask.ImageUrl {
|
||||
return true
|
||||
}
|
||||
if oldTask.Status != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
|
||||
@@ -34,6 +34,8 @@ func GetStatus(c *gin.Context) {
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
},
|
||||
})
|
||||
return
|
||||
|
||||
@@ -360,6 +360,24 @@ func init() {
|
||||
Root: "code-davinci-edit-001",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "babbage-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "babbage-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "davinci-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "davinci-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-instant-1",
|
||||
Object: "model",
|
||||
@@ -432,6 +450,15 @@ func init() {
|
||||
Root: "gemini-pro",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro-vision",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro-vision",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_turbo",
|
||||
Object: "model",
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
@@ -97,6 +101,31 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
},
|
||||
},
|
||||
}
|
||||
openaiContent := message.ParseContent()
|
||||
var parts []GeminiPart
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
|
||||
if part.Type == ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
if imageNum > GeminiVisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
|
||||
@@ -55,9 +55,18 @@ type MidjourneyWithoutStatus struct {
|
||||
ChannelId int `json:"channel_id"`
|
||||
}
|
||||
|
||||
var DefaultModelPrice = map[string]float64{
|
||||
"mj_imagine": 0.1,
|
||||
"mj_variation": 0.1,
|
||||
"mj_reroll": 0.1,
|
||||
"mj_blend": 0.1,
|
||||
"mj_describe": 0.05,
|
||||
"mj_upscale": 0.05,
|
||||
}
|
||||
|
||||
func RelayMidjourneyImage(c *gin.Context) {
|
||||
taskId := c.Param("id")
|
||||
midjourneyTask := model.GetByMJId(taskId)
|
||||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||||
if midjourneyTask == nil {
|
||||
c.JSON(400, gin.H{
|
||||
"error": "midjourney_task_not_found",
|
||||
@@ -71,14 +80,27 @@ func RelayMidjourneyImage(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
responseBody, _ := io.ReadAll(resp.Body)
|
||||
c.JSON(resp.StatusCode, gin.H{
|
||||
"error": string(responseBody),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Header("Content-Type", "image/jpeg")
|
||||
//c.HeaderBar("Content-Length", string(rune(len(data))))
|
||||
c.Data(http.StatusOK, "image/jpeg", data)
|
||||
// 从Content-Type头获取MIME类型
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
// 如果无法确定内容类型,则默认为jpeg
|
||||
contentType = "image/jpeg"
|
||||
}
|
||||
// 设置响应的内容类型
|
||||
c.Writer.Header().Set("Content-Type", contentType)
|
||||
// 将图片流式传输到响应体
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
log.Println("Failed to stream image:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
||||
@@ -92,7 +114,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask := model.GetByMJId(midjRequest.MjId)
|
||||
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
||||
if midjourneyTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -121,16 +143,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
||||
return nil
|
||||
}
|
||||
|
||||
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
taskId := c.Param("id")
|
||||
originTask := model.GetByMJId(taskId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
}
|
||||
var midjourneyTask Midjourney
|
||||
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
|
||||
midjourneyTask.MjId = originTask.MjId
|
||||
midjourneyTask.Progress = originTask.Progress
|
||||
midjourneyTask.PromptEn = originTask.PromptEn
|
||||
@@ -150,14 +163,65 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
midjourneyTask.Action = originTask.Action
|
||||
midjourneyTask.Description = originTask.Description
|
||||
midjourneyTask.Prompt = originTask.Prompt
|
||||
jsonMap, err := json.Marshal(midjourneyTask)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
return
|
||||
}
|
||||
|
||||
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
userId := c.GetInt("id")
|
||||
var err error
|
||||
var respBody []byte
|
||||
switch relayMode {
|
||||
case RelayModeMidjourneyTaskFetch:
|
||||
taskId := c.Param("id")
|
||||
originTask := model.GetByMJId(userId, taskId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
}
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
respBody, err = json.Marshal(midjourneyTask)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
case RelayModeMidjourneyTaskFetchByCondition:
|
||||
var condition = struct {
|
||||
IDs []string `json:"ids"`
|
||||
}{}
|
||||
err = c.BindJSON(&condition)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
var tasks []Midjourney
|
||||
if len(condition.IDs) != 0 {
|
||||
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||||
for _, originTask := range originTasks {
|
||||
midjourneyTask := getMidjourneyTaskModel(c, originTask)
|
||||
tasks = append(tasks, midjourneyTask)
|
||||
}
|
||||
}
|
||||
if tasks == nil {
|
||||
tasks = make([]Midjourney, 0)
|
||||
}
|
||||
respBody, err = json.Marshal(tasks)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -167,6 +231,18 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
// type 1 根据 mode 价格不同
|
||||
MJSubmitActionImagine = "IMAGINE"
|
||||
MJSubmitActionVariation = "VARIATION" //变换
|
||||
MJSubmitActionBlend = "BLEND" //混图
|
||||
|
||||
MJSubmitActionReroll = "REROLL" //重新生成
|
||||
// type 2 固定价格
|
||||
MJSubmitActionDescribe = "DESCRIBE"
|
||||
MJSubmitActionUpscale = "UPSCALE" // 放大
|
||||
)
|
||||
|
||||
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
imageModel := "midjourney"
|
||||
|
||||
@@ -186,6 +262,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||||
if midjRequest.Prompt == "" {
|
||||
return &MidjourneyResponse{
|
||||
@@ -199,7 +276,45 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||
midjRequest.Action = "BLEND"
|
||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||
originTask := model.GetByMJId(midjRequest.TaskId)
|
||||
mjId := ""
|
||||
if relayMode == RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "taskId_is_required",
|
||||
}
|
||||
} else if midjRequest.Action == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "action_is_required",
|
||||
}
|
||||
} else if midjRequest.Index == 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "index_can_only_be_1_2_3_4",
|
||||
}
|
||||
}
|
||||
//action = midjRequest.Action
|
||||
mjId = midjRequest.TaskId
|
||||
} else if relayMode == RelayModeMidjourneySimpleChange {
|
||||
if midjRequest.Content == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_is_required",
|
||||
}
|
||||
}
|
||||
params := convertSimpleChangeParams(midjRequest.Content)
|
||||
if params == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "content_parse_failed",
|
||||
}
|
||||
}
|
||||
mjId = params.ID
|
||||
midjRequest.Action = params.Action
|
||||
}
|
||||
|
||||
originTask := model.GetByMJId(userId, mjId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
@@ -229,23 +344,6 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
log.Printf("检测到此操作为放大、变换,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
} else if relayMode == RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "taskId_is_required",
|
||||
}
|
||||
} else if midjRequest.Action == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "action_is_required",
|
||||
}
|
||||
} else if midjRequest.Index == 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "index_can_only_be_1_2_3_4",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// map model name
|
||||
@@ -292,18 +390,27 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
|
||||
sizeRatio := 1.0
|
||||
if midjRequest.Action == "UPSCALE" {
|
||||
sizeRatio = 0.2
|
||||
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
|
||||
modelPrice := common.GetModelPrice(mjAction)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if modelPrice == -1 {
|
||||
defaultPrice, ok := DefaultModelPrice[mjAction]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
|
||||
quota := int(ratio * sizeRatio * 1000)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
return &MidjourneyResponse{
|
||||
@@ -369,7 +476,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
@@ -504,3 +611,38 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type taskChangeParams struct {
|
||||
ID string
|
||||
Action string
|
||||
Index int
|
||||
}
|
||||
|
||||
func convertSimpleChangeParams(content string) *taskChangeParams {
|
||||
split := strings.Split(content, " ")
|
||||
if len(split) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
action := strings.ToLower(split[1])
|
||||
changeParams := &taskChangeParams{}
|
||||
changeParams.ID = split[0]
|
||||
|
||||
if action[0] == 'u' {
|
||||
changeParams.Action = "UPSCALE"
|
||||
} else if action[0] == 'v' {
|
||||
changeParams.Action = "VARIATION"
|
||||
} else if action == "r" {
|
||||
changeParams.Action = "REROLL"
|
||||
return changeParams
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(action[1:2])
|
||||
if err != nil || index < 1 || index > 4 {
|
||||
return nil
|
||||
}
|
||||
changeParams.Index = index
|
||||
return changeParams
|
||||
}
|
||||
|
||||
@@ -231,14 +231,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
case RelayModeModerations:
|
||||
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
||||
}
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
||||
}
|
||||
modelRatio := common.GetModelRatio(textRequest.Model)
|
||||
modelPrice := common.GetModelPrice(textRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
if modelPrice == -1 {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
||||
}
|
||||
modelRatio = common.GetModelRatio(textRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
@@ -404,13 +414,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
case APITypeTencent:
|
||||
req.Header.Set("Authorization", apiKey)
|
||||
case APITypeGemini:
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if apiType != APITypeGemini {
|
||||
// 设置公共头部...
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
|
||||
resp, err = httpClient.Do(req)
|
||||
@@ -447,15 +462,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
defer func(ctx context.Context) {
|
||||
// c.Writer.Flush()
|
||||
go func() {
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
quota := 0
|
||||
if modelPrice == -1 {
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
@@ -472,10 +491,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
|
||||
var logContent string
|
||||
if modelPrice == -1 {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds)
|
||||
}
|
||||
logModel := textRequest.Model
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
||||
}
|
||||
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
//if quota != 0 {
|
||||
|
||||
@@ -76,12 +76,13 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
|
||||
}
|
||||
var config image.Config
|
||||
var err error
|
||||
var format string
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
|
||||
config, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -101,7 +102,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
|
||||
|
||||
shortSide := config.Width
|
||||
otherSide := config.Height
|
||||
log.Printf("width: %d, height: %d", config.Width, config.Height)
|
||||
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
|
||||
// 缩放倍数
|
||||
scale := 1.0
|
||||
if config.Height < shortSide {
|
||||
@@ -194,12 +195,12 @@ func countTokenMessages(messages []Message, model string) (int, error) {
|
||||
}
|
||||
|
||||
func countTokenInput(input any, model string) int {
|
||||
switch input.(type) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return countTokenText(input.(string), model)
|
||||
return countTokenText(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range input.([]string) {
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return countTokenText(text, model)
|
||||
|
||||
@@ -29,6 +29,60 @@ type MessageImageUrl struct {
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
)
|
||||
|
||||
func (m Message) ParseContent() []MediaMessage {
|
||||
var contentList []MediaMessage
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: stringContent,
|
||||
})
|
||||
return contentList
|
||||
}
|
||||
var arrayContent []json.RawMessage
|
||||
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
|
||||
for _, contentItem := range arrayContent {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
|
||||
continue
|
||||
}
|
||||
switch contentMap["type"] {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
}
|
||||
case ContentTypeImageURL:
|
||||
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
||||
detail, ok := subObj["detail"]
|
||||
if ok {
|
||||
subObj["detail"] = detail.(string)
|
||||
} else {
|
||||
subObj["detail"] = "auto"
|
||||
}
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
Detail: subObj["detail"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
@@ -41,8 +95,10 @@ const (
|
||||
RelayModeMidjourneyDescribe
|
||||
RelayModeMidjourneyBlend
|
||||
RelayModeMidjourneyChange
|
||||
RelayModeMidjourneySimpleChange
|
||||
RelayModeMidjourneyNotify
|
||||
RelayModeMidjourneyTaskFetch
|
||||
RelayModeMidjourneyTaskFetchByCondition
|
||||
RelayModeAudio
|
||||
)
|
||||
|
||||
@@ -209,6 +265,7 @@ type MidjourneyRequest struct {
|
||||
State string `json:"state"`
|
||||
TaskId string `json:"taskId"`
|
||||
Base64Array []string `json:"base64Array"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type MidjourneyResponse struct {
|
||||
@@ -288,14 +345,19 @@ func RelayMidjourney(c *gin.Context) {
|
||||
relayMode = RelayModeMidjourneyNotify
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
|
||||
relayMode = RelayModeMidjourneyTaskFetch
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
|
||||
relayMode = RelayModeMidjourneyTaskFetchByCondition
|
||||
}
|
||||
|
||||
var err *MidjourneyResponse
|
||||
switch relayMode {
|
||||
case RelayModeMidjourneyNotify:
|
||||
err = relayMidjourneyNotify(c)
|
||||
case RelayModeMidjourneyTaskFetch:
|
||||
case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
|
||||
err = relayMidjourneyTask(c, relayMode)
|
||||
default:
|
||||
err = relayMidjourneySubmit(c, relayMode)
|
||||
|
||||
@@ -217,6 +217,8 @@ func UpdateToken(c *gin.Context) {
|
||||
cleanToken.ExpiredTime = token.ExpiredTime
|
||||
cleanToken.RemainQuota = token.RemainQuota
|
||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
|
||||
cleanToken.ModelLimits = token.ModelLimits
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
|
||||
48
controller/usedata.go
Normal file
48
controller/usedata.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func GetAllQuotaDates(c *gin.Context) {
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
username := c.Query("username")
|
||||
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": dates,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserQuotaDates(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
|
||||
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": dates,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -152,6 +152,21 @@ func Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if exist {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名已存在,或已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
affCode := user.AffCode // this code is the inviter's code, not the user's own code
|
||||
inviterId, _ := model.GetUserIdByAffCode(affCode)
|
||||
cleanUser := model.User{
|
||||
@@ -525,7 +540,7 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.DeleteUserById(id)
|
||||
err = model.HardDeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -632,7 +647,7 @@ func ManageUser(c *gin.Context) {
|
||||
Username: req.Username,
|
||||
}
|
||||
// Fill attributes
|
||||
model.DB.Where(&user).First(&user)
|
||||
model.DB.Unscoped().Where(&user).First(&user)
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -668,7 +683,7 @@ func ManageUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.Delete(); err != nil {
|
||||
if err := user.HardDelete(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
|
||||
12
go.mod
12
go.mod
@@ -19,7 +19,7 @@ require (
|
||||
github.com/samber/lo v1.38.1
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/crypto v0.17.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@@ -44,13 +44,14 @@ require (
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.1 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
@@ -64,8 +65,9 @@ require (
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
14
go.sum
14
go.sum
@@ -81,6 +81,10 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
|
||||
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
|
||||
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
|
||||
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
@@ -108,6 +112,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
@@ -172,11 +178,15 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -189,12 +199,16 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
|
||||
18
main.go
18
main.go
@@ -27,7 +27,7 @@ var indexPage []byte
|
||||
|
||||
func main() {
|
||||
common.SetupLogger()
|
||||
common.SysLog("One API " + common.Version + " started")
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
@@ -67,6 +67,10 @@ func main() {
|
||||
go model.SyncOptions(common.SyncFrequency)
|
||||
go model.SyncChannelCache(common.SyncFrequency)
|
||||
}
|
||||
|
||||
// 数据看板
|
||||
go model.UpdateQuotaData()
|
||||
|
||||
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
|
||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
|
||||
if err != nil {
|
||||
@@ -81,7 +85,7 @@ func main() {
|
||||
}
|
||||
go controller.AutomaticallyTestChannels(frequency)
|
||||
}
|
||||
go controller.UpdateMidjourneyTask()
|
||||
go controller.UpdateMidjourneyTaskBulk()
|
||||
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||
common.BatchUpdateEnabled = true
|
||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||
@@ -100,7 +104,15 @@ func main() {
|
||||
|
||||
// Initialize HTTP server
|
||||
server := gin.New()
|
||||
server.Use(gin.Recovery())
|
||||
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
|
||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||
"type": "new_api_panic",
|
||||
},
|
||||
})
|
||||
}))
|
||||
// This will cause SSE not to work!!!
|
||||
//server.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
server.Use(middleware.RequestId())
|
||||
|
||||
@@ -115,6 +115,12 @@ func TokenAuth() func(c *gin.Context) {
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
if token.ModelLimitsEnabled {
|
||||
c.Set("token_model_limit_enabled", true)
|
||||
c.Set("token_model_limit", token.GetModelLimitsMap())
|
||||
} else {
|
||||
c.Set("token_model_limit_enabled", false)
|
||||
}
|
||||
requestURL := c.Request.URL.String()
|
||||
consumeQuota := true
|
||||
if strings.HasPrefix(requestURL, "/v1/models") {
|
||||
|
||||
@@ -50,7 +50,7 @@ func Distribute() func(c *gin.Context) {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求: "+err.Error())
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
@@ -77,6 +77,27 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
if modelLimitEnable {
|
||||
s, ok := c.Get("token_model_limit")
|
||||
var tokenModelLimit map[string]bool
|
||||
if ok {
|
||||
tokenModelLimit = s.(map[string]bool)
|
||||
} else {
|
||||
tokenModelLimit = map[string]bool{}
|
||||
}
|
||||
if tokenModelLimit != nil {
|
||||
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
|
||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// token model limit is empty, all models are not allowed
|
||||
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
|
||||
return
|
||||
}
|
||||
}
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
|
||||
28
middleware/recover.go
Normal file
28
middleware/recover.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func RelayPanicRecover() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
||||
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
|
||||
"type": "new_api_panic",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
|
||||
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
||||
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
|
||||
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
|
||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||
Weight uint `json:"weight" gorm:"default:0;index"`
|
||||
}
|
||||
|
||||
func GetGroupModels(group string) []string {
|
||||
@@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
var abilities []Ability
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
@@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
} else {
|
||||
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channel := Channel{}
|
||||
channel.Id = ability.ChannelId
|
||||
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
|
||||
if len(abilities) > 0 {
|
||||
// Randomly choose one
|
||||
weightSum := uint(0)
|
||||
for _, ability_ := range abilities {
|
||||
weightSum += ability_.Weight
|
||||
}
|
||||
if weightSum == 0 {
|
||||
// All weight is 0, randomly choose one
|
||||
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
|
||||
} else {
|
||||
// Randomly choose one
|
||||
weight := common.GetRandomInt(int(weightSum))
|
||||
for _, ability_ := range abilities {
|
||||
weight -= int(ability_.Weight)
|
||||
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||
if weight <= 0 {
|
||||
channel.Id = ability_.ChannelId
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
@@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
|
||||
ChannelId: channel.Id,
|
||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||
Priority: channel.Priority,
|
||||
Weight: uint(channel.GetWeight()),
|
||||
}
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
@@ -149,10 +150,12 @@ func InitChannelCache() {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
newChannelsIDM := make(map[int]*Channel)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
newChannelsIDM[channel.Id] = channel
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
@@ -177,6 +180,7 @@ func InitChannelCache() {
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelsIDM = newChannelsIDM
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
@@ -194,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
}
|
||||
@@ -214,6 +219,41 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
}
|
||||
}
|
||||
}
|
||||
idx := rand.Intn(endIdx)
|
||||
return channels[idx], nil
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := 0
|
||||
for _, channel := range channels[:endIdx] {
|
||||
totalWeight += channel.GetWeight()
|
||||
}
|
||||
|
||||
if totalWeight == 0 {
|
||||
// If all weights are 0, select a channel randomly
|
||||
return channels[rand.Intn(endIdx)], nil
|
||||
}
|
||||
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range channels[:endIdx] {
|
||||
randomWeight -= channel.GetWeight()
|
||||
if randomWeight <= 0 {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
// return the last channel if no channel is found
|
||||
return channels[endIdx-1], nil
|
||||
}
|
||||
|
||||
func CacheGetChannel(id int) (*Channel, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetChannelById(id, true)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type Channel struct {
|
||||
Balance float64 `json:"balance"` // in USD
|
||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||
Models string `json:"models"`
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
@@ -113,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
|
||||
return *channel.Priority
|
||||
}
|
||||
|
||||
func (channel *Channel) GetWeight() int {
|
||||
if channel.Weight == nil {
|
||||
return 0
|
||||
}
|
||||
return int(*channel.Weight)
|
||||
}
|
||||
|
||||
func (channel *Channel) GetBaseURL() string {
|
||||
if channel.BaseURL == nil {
|
||||
return ""
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id;index:idx_created_at_id,priority:1"`
|
||||
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
@@ -59,9 +59,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
if !common.LogConsumeEnabled {
|
||||
return
|
||||
}
|
||||
username := GetUsernameById(userId)
|
||||
log := &Log{
|
||||
UserId: userId,
|
||||
Username: GetUsernameById(userId),
|
||||
Username: username,
|
||||
CreatedAt: common.GetTimestamp(),
|
||||
Type: LogTypeConsume,
|
||||
Content: content,
|
||||
@@ -77,6 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
||||
if err != nil {
|
||||
common.LogError(ctx, "failed to record log: "+err.Error())
|
||||
}
|
||||
if common.DataExportEnabled {
|
||||
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp())
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
|
||||
|
||||
@@ -52,6 +52,14 @@ func chooseDB() (*gorm.DB, error) {
|
||||
}
|
||||
// Use MySQL
|
||||
common.SysLog("using MySQL as database")
|
||||
// check parseTime
|
||||
if !strings.Contains(dsn, "parseTime") {
|
||||
if strings.Contains(dsn, "?") {
|
||||
dsn += "&parseTime=true"
|
||||
} else {
|
||||
dsn += "?parseTime=true"
|
||||
}
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
@@ -119,6 +127,10 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.AutoMigrate(&QuotaData{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.SysLog("database migrated")
|
||||
err = createRootAccountIfNeed()
|
||||
return err
|
||||
|
||||
@@ -96,7 +96,7 @@ func GetAllUnFinishTasks() []*Midjourney {
|
||||
return tasks
|
||||
}
|
||||
|
||||
func GetByMJId(mjId string) *Midjourney {
|
||||
func GetByOnlyMJId(mjId string) *Midjourney {
|
||||
var mj *Midjourney
|
||||
var err error
|
||||
err = DB.Where("mj_id = ?", mjId).First(&mj).Error
|
||||
@@ -106,6 +106,26 @@ func GetByMJId(mjId string) *Midjourney {
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetByMJId(userId int, mjId string) *Midjourney {
|
||||
var mj *Midjourney
|
||||
var err error
|
||||
err = DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetByMJIds(userId int, mjIds []string) []*Midjourney {
|
||||
var mj []*Midjourney
|
||||
var err error
|
||||
err = DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetMjByuId(id int) *Midjourney {
|
||||
var mj *Midjourney
|
||||
var err error
|
||||
@@ -131,3 +151,9 @@ func (midjourney *Midjourney) Update() error {
|
||||
err = DB.Save(midjourney).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func MjBulkUpdate(taskIDs []string, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("mj_id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
@@ -37,6 +37,8 @@ func InitOptionMap() {
|
||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
||||
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
|
||||
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
|
||||
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
|
||||
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
|
||||
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
|
||||
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
|
||||
@@ -70,11 +72,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
@@ -156,6 +160,12 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.LogConsumeEnabled = boolValue
|
||||
case "DisplayInCurrencyEnabled":
|
||||
common.DisplayInCurrencyEnabled = boolValue
|
||||
case "DisplayTokenStatEnabled":
|
||||
common.DisplayTokenStatEnabled = boolValue
|
||||
case "DrawingEnabled":
|
||||
common.DrawingEnabled = boolValue
|
||||
case "DataExportEnabled":
|
||||
common.DataExportEnabled = boolValue
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -216,10 +226,14 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.PreConsumedQuota, _ = strconv.Atoi(value)
|
||||
case "RetryTimes":
|
||||
common.RetryTimes, _ = strconv.Atoi(value)
|
||||
case "DataExportInterval":
|
||||
common.DataExportInterval, _ = strconv.Atoi(value)
|
||||
case "ModelRatio":
|
||||
err = common.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = common.UpdateGroupRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = common.UpdateModelPriceByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
case "ChatLink":
|
||||
|
||||
@@ -10,17 +10,19 @@ import (
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int `json:"remain_quota" gorm:"default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
|
||||
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
|
||||
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
|
||||
}
|
||||
|
||||
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
||||
@@ -107,7 +109,7 @@ func (token *Token) Insert() error {
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (token *Token) Update() error {
|
||||
var err error
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -122,6 +124,36 @@ func (token *Token) Delete() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (token *Token) IsModelLimitsEnabled() bool {
|
||||
return token.ModelLimitsEnabled
|
||||
}
|
||||
|
||||
func (token *Token) GetModelLimits() []string {
|
||||
if token.ModelLimits == "" {
|
||||
return []string{}
|
||||
}
|
||||
return strings.Split(token.ModelLimits, ",")
|
||||
}
|
||||
|
||||
func (token *Token) GetModelLimitsMap() map[string]bool {
|
||||
limits := token.GetModelLimits()
|
||||
limitsMap := make(map[string]bool)
|
||||
for _, limit := range limits {
|
||||
limitsMap[limit] = true
|
||||
}
|
||||
return limitsMap
|
||||
}
|
||||
|
||||
func DisableModelLimits(tokenId int) error {
|
||||
token, err := GetTokenById(tokenId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
token.ModelLimitsEnabled = false
|
||||
token.ModelLimits = ""
|
||||
return token.Update()
|
||||
}
|
||||
|
||||
func DeleteTokenById(id int, userId int) (err error) {
|
||||
// Why we need userId here? In case user want to delete other's token.
|
||||
if id == 0 || userId == 0 {
|
||||
|
||||
115
model/usedata.go
Normal file
115
model/usedata.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// QuotaData 柱状图数据
|
||||
type QuotaData struct {
|
||||
Id int `json:"id"`
|
||||
UserID int `json:"user_id" gorm:"index"`
|
||||
Username string `json:"username" gorm:"index:idx_qdt_model_user_name,priority:2;size:64;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index:idx_qdt_model_user_name,priority:1;size:64;default:''"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_qdt_created_at,priority:2"`
|
||||
Count int `json:"count" gorm:"default:0"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
}
|
||||
|
||||
func UpdateQuotaData() {
|
||||
// recover
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
|
||||
}
|
||||
}()
|
||||
for {
|
||||
if common.DataExportEnabled {
|
||||
common.SysLog("正在更新数据看板数据...")
|
||||
SaveQuotaDataCache()
|
||||
}
|
||||
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
var CacheQuotaData = make(map[string]*QuotaData)
|
||||
var CacheQuotaDataLock = sync.Mutex{}
|
||||
|
||||
func LogQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64) {
|
||||
// 只精确到小时
|
||||
createdAt = createdAt - (createdAt % 3600)
|
||||
key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt)
|
||||
quotaData, ok := CacheQuotaData[key]
|
||||
if ok {
|
||||
quotaData.Count += 1
|
||||
quotaData.Quota += quota
|
||||
} else {
|
||||
quotaData = &QuotaData{
|
||||
UserID: userId,
|
||||
Username: username,
|
||||
ModelName: modelName,
|
||||
CreatedAt: createdAt,
|
||||
Count: 1,
|
||||
Quota: quota,
|
||||
}
|
||||
}
|
||||
CacheQuotaData[key] = quotaData
|
||||
}
|
||||
|
||||
func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64) {
|
||||
CacheQuotaDataLock.Lock()
|
||||
defer CacheQuotaDataLock.Unlock()
|
||||
LogQuotaDataCache(userId, username, modelName, quota, createdAt)
|
||||
}
|
||||
|
||||
func SaveQuotaDataCache() {
|
||||
CacheQuotaDataLock.Lock()
|
||||
defer CacheQuotaDataLock.Unlock()
|
||||
size := len(CacheQuotaData)
|
||||
// 如果缓存中有数据,就保存到数据库中
|
||||
// 1. 先查询数据库中是否有数据
|
||||
// 2. 如果有数据,就更新数据
|
||||
// 3. 如果没有数据,就插入数据
|
||||
for _, quotaData := range CacheQuotaData {
|
||||
quotaDataDB := &QuotaData{}
|
||||
DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
|
||||
quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.CreatedAt).First(quotaDataDB)
|
||||
if quotaDataDB.Id > 0 {
|
||||
quotaDataDB.Count += quotaData.Count
|
||||
quotaDataDB.Quota += quotaData.Quota
|
||||
DB.Table("quota_data").Save(quotaDataDB)
|
||||
} else {
|
||||
DB.Table("quota_data").Create(quotaData)
|
||||
}
|
||||
}
|
||||
CacheQuotaData = make(map[string]*QuotaData)
|
||||
common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
|
||||
}
|
||||
|
||||
func GetQuotaDataByUsername(username string, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
|
||||
var quotaDatas []*QuotaData
|
||||
// 从quota_data表中查询数据
|
||||
err = DB.Table("quota_data").Where("username = ?", username).Find("aDatas).Error
|
||||
return quotaDatas, err
|
||||
}
|
||||
|
||||
func GetQuotaDataByUserId(userId int, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
|
||||
var quotaDatas []*QuotaData
|
||||
// 从quota_data表中查询数据
|
||||
err = DB.Table("quota_data").Where("user_id = ? and created_at >= ? and created_at <= ?", userId, startTime, endTime).Find("aDatas).Error
|
||||
return quotaDatas, err
|
||||
}
|
||||
|
||||
func GetAllQuotaDates(startTime int64, endTime int64, username string) (quotaData []*QuotaData, err error) {
|
||||
if username != "" {
|
||||
return GetQuotaDataByUsername(username, startTime, endTime)
|
||||
}
|
||||
var quotaDatas []*QuotaData
|
||||
// 从quota_data表中查询数据
|
||||
// only select model_name, sum(count) as count, sum(quota) as quota, model_name, created_at from quota_data group by model_name, created_at;
|
||||
//err = DB.Table("quota_data").Where("created_at >= ? and created_at <= ?", startTime, endTime).Find("aDatas).Error
|
||||
err = DB.Table("quota_data").Select("model_name, sum(count) as count, sum(quota) as quota, created_at").Where("created_at >= ? and created_at <= ?", startTime, endTime).Group("model_name, created_at").Find("aDatas).Error
|
||||
return quotaDatas, err
|
||||
}
|
||||
@@ -6,31 +6,57 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||
// Otherwise, the sensitive information will be saved on local storage in plain text!
|
||||
type User struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
||||
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
||||
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
|
||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
||||
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
||||
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
|
||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
||||
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
var user User
|
||||
|
||||
// err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
|
||||
// check email if empty
|
||||
var err error
|
||||
if email == "" {
|
||||
err = DB.Unscoped().First(&user, "username = ?", username).Error
|
||||
} else {
|
||||
err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// not exist, return false, nil
|
||||
return false, nil
|
||||
}
|
||||
// other error, return false, err
|
||||
return false, err
|
||||
}
|
||||
// exist, return true, nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetMaxUserId() int {
|
||||
@@ -40,7 +66,7 @@ func GetMaxUserId() int {
|
||||
}
|
||||
|
||||
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
||||
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
||||
err = DB.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
@@ -80,6 +106,14 @@ func DeleteUserById(id int) (err error) {
|
||||
return user.Delete()
|
||||
}
|
||||
|
||||
func HardDeleteUserById(id int) error {
|
||||
if id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func inviteUser(inviterId int) (err error) {
|
||||
user, err := GetUserById(inviterId, true)
|
||||
if err != nil {
|
||||
@@ -169,9 +203,13 @@ func (user *User) Update(updatePassword bool) error {
|
||||
}
|
||||
}
|
||||
newUser := *user
|
||||
|
||||
DB.First(&user, user.Id)
|
||||
err = DB.Model(user).Updates(newUser).Error
|
||||
if err == nil {
|
||||
if common.RedisEnabled {
|
||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -183,6 +221,14 @@ func (user *User) Delete() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (user *User) HardDelete() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
err := DB.Unscoped().Delete(user).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateAndFill check password & user status
|
||||
func (user *User) ValidateAndFill() (err error) {
|
||||
// When querying with struct, GORM will only query with non-zero fields,
|
||||
|
||||
@@ -114,6 +114,10 @@ func SetApiRouter(router *gin.Engine) {
|
||||
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
|
||||
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
|
||||
|
||||
dataRoute := apiRouter.Group("/data")
|
||||
dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates)
|
||||
dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates)
|
||||
|
||||
logRoute.Use(middleware.CORS())
|
||||
{
|
||||
logRoute.GET("/token", controller.GetLogByKey)
|
||||
|
||||
@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
relayV1Router.POST("/completions", controller.Relay)
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
@@ -49,10 +49,12 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
{
|
||||
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/notify", controller.RelayMidjourney)
|
||||
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
|
||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||
}
|
||||
//relayMjRouter.Use()
|
||||
}
|
||||
|
||||
@@ -3,13 +3,17 @@
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"dependencies": {
|
||||
"@douyinfe/semi-ui": "^2.45.2",
|
||||
"@douyinfe/semi-ui": "^2.46.1",
|
||||
"@visactor/vchart": "~1.7.2",
|
||||
"@visactor/react-vchart": "~1.7.2",
|
||||
"@visactor/vchart-semi-theme": "~1.7.2",
|
||||
"axios": "^0.27.2",
|
||||
"history": "^5.3.0",
|
||||
"marked": "^4.1.1",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-dropzone": "^14.2.3",
|
||||
"react-fireworks": "^1.0.4",
|
||||
"react-router-dom": "^6.3.0",
|
||||
"react-scripts": "5.0.1",
|
||||
"react-toastify": "^9.0.8",
|
||||
@@ -43,7 +47,8 @@
|
||||
]
|
||||
},
|
||||
"devDependencies": {
|
||||
"prettier": "^2.7.1"
|
||||
"prettier": "^2.7.1",
|
||||
"typescript": "4.4.2"
|
||||
},
|
||||
"prettier": {
|
||||
"singleQuote": true,
|
||||
|
||||
@@ -23,10 +23,10 @@ import Log from './pages/Log';
|
||||
import Chat from './pages/Chat';
|
||||
import {Layout} from "@douyinfe/semi-ui";
|
||||
import Midjourney from "./pages/Midjourney";
|
||||
import Detail from "./pages/Detail";
|
||||
|
||||
const Home = lazy(() => import('./pages/Home'));
|
||||
const About = lazy(() => import('./pages/About'));
|
||||
|
||||
function App() {
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
const [statusState, statusDispatch] = useContext(StatusContext);
|
||||
@@ -49,6 +49,8 @@ function App() {
|
||||
localStorage.setItem('footer_html', data.footer_html);
|
||||
localStorage.setItem('quota_per_unit', data.quota_per_unit);
|
||||
localStorage.setItem('display_in_currency', data.display_in_currency);
|
||||
localStorage.setItem('enable_drawing', data.enable_drawing);
|
||||
localStorage.setItem('enable_data_export', data.enable_data_export);
|
||||
if (data.chat_link) {
|
||||
localStorage.setItem('chat_link', data.chat_link);
|
||||
} else {
|
||||
@@ -228,6 +230,14 @@ function App() {
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/detail'
|
||||
element={
|
||||
<PrivateRoute>
|
||||
<Detail />
|
||||
</PrivateRoute>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/midjourney'
|
||||
element={
|
||||
|
||||
@@ -163,11 +163,30 @@ const ChannelsTable = () => {
|
||||
<div>
|
||||
<InputNumber
|
||||
style={{width: 70}}
|
||||
name='name'
|
||||
name='priority'
|
||||
onChange={value => {
|
||||
manageChannel(record.id, 'priority', record, value);
|
||||
}}
|
||||
defaultValue={record.priority}
|
||||
min={-999}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '权重',
|
||||
dataIndex: 'weight',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
<InputNumber
|
||||
style={{width: 70}}
|
||||
name='weight'
|
||||
onChange={value => {
|
||||
manageChannel(record.id, 'weight', record, value);
|
||||
}}
|
||||
defaultValue={record.weight}
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, {useContext, useEffect, useState} from 'react';
|
||||
import React, {useContext, useEffect, useRef, useState} from 'react';
|
||||
import {Link, useNavigate} from 'react-router-dom';
|
||||
import {UserContext} from '../context/User';
|
||||
|
||||
@@ -6,18 +6,12 @@ import {Button, Container, Icon, Menu, Segment} from 'semantic-ui-react';
|
||||
import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../helpers';
|
||||
import '../index.css';
|
||||
|
||||
import fireworks from 'react-fireworks';
|
||||
|
||||
import {
|
||||
IconAt,
|
||||
IconHistogram,
|
||||
IconGift,
|
||||
IconKey,
|
||||
IconUser,
|
||||
IconLayers,
|
||||
IconHelpCircle,
|
||||
IconCreditCard,
|
||||
IconSemiLogo,
|
||||
IconHome,
|
||||
IconImage
|
||||
IconHelpCircle
|
||||
} from '@douyinfe/semi-icons';
|
||||
import {Nav, Avatar, Dropdown, Layout, Switch} from '@douyinfe/semi-ui';
|
||||
import {stringToColor} from "../helpers/render";
|
||||
@@ -49,6 +43,8 @@ const HeaderBar = () => {
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
var themeMode = localStorage.getItem('theme-mode');
|
||||
const currentDate = new Date();
|
||||
const isNewYear = currentDate.getMonth() === 0 && currentDate.getDate() === 1;
|
||||
|
||||
async function logout() {
|
||||
setShowSidebar(false);
|
||||
@@ -59,10 +55,24 @@ const HeaderBar = () => {
|
||||
navigate('/login');
|
||||
}
|
||||
|
||||
const handleNewYearClick = () => {
|
||||
fireworks.init("root",{});
|
||||
fireworks.start();
|
||||
setTimeout(() => {
|
||||
fireworks.stop();
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 10000);
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (themeMode === 'dark') {
|
||||
switchMode(true);
|
||||
}
|
||||
if (isNewYear) {
|
||||
console.log('Happy New Year!');
|
||||
}
|
||||
}, []);
|
||||
|
||||
const switchMode = (model) => {
|
||||
@@ -105,6 +115,19 @@ const HeaderBar = () => {
|
||||
}}
|
||||
footer={
|
||||
<>
|
||||
{isNewYear &&
|
||||
// happy new year
|
||||
<Dropdown
|
||||
position="bottomRight"
|
||||
render={
|
||||
<Dropdown.Menu>
|
||||
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Nav.Item itemKey={'new-year'} text={'🏮'}/>
|
||||
</Dropdown>
|
||||
}
|
||||
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
|
||||
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
|
||||
{userState.user ?
|
||||
|
||||
@@ -287,7 +287,7 @@ const LogsTable = () => {
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||
console.log(logCount);
|
||||
// console.log(logCount);
|
||||
}
|
||||
|
||||
const loadLogs = async (startIdx) => {
|
||||
@@ -422,7 +422,6 @@ const LogsTable = () => {
|
||||
value={end_timestamp} type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={value => handleInputChange(value, 'end_timestamp')}/>
|
||||
{/*<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>*/}
|
||||
{
|
||||
isAdminUser && <>
|
||||
<Form.Input field="channel" label='渠道 ID' style={{width: 176}} value={channel}
|
||||
|
||||
@@ -307,7 +307,7 @@ const LogsTable = () => {
|
||||
// data.key = '' + data.id
|
||||
setLogs(logs);
|
||||
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||
console.log(logCount);
|
||||
// console.log(logCount);
|
||||
}
|
||||
|
||||
const loadLogs = async (startIdx) => {
|
||||
|
||||
@@ -3,13 +3,15 @@ import {Divider, Form, Grid, Header} from 'semantic-ui-react';
|
||||
import {API, showError, showSuccess, timestamp2string, verifyJSON} from '../helpers';
|
||||
|
||||
const OperationSetting = () => {
|
||||
let now = new Date();let [inputs, setInputs] = useState({
|
||||
let now = new Date();
|
||||
let [inputs, setInputs] = useState({
|
||||
QuotaForNewUser: 0,
|
||||
QuotaForInviter: 0,
|
||||
QuotaForInvitee: 0,
|
||||
QuotaRemindThreshold: 0,
|
||||
PreConsumedQuota: 0,
|
||||
ModelRatio: '',
|
||||
ModelPrice: '',
|
||||
GroupRatio: '',
|
||||
TopUpLink: '',
|
||||
ChatLink: '',
|
||||
@@ -19,28 +21,32 @@ const OperationSetting = () => {
|
||||
LogConsumeEnabled: '',
|
||||
DisplayInCurrencyEnabled: '',
|
||||
DisplayTokenStatEnabled: '',
|
||||
DrawingEnabled: '',
|
||||
DataExportEnabled: '',
|
||||
DataExportInterval: 5,
|
||||
RetryTimes: 0
|
||||
});
|
||||
const [originInputs, setOriginInputs] = useState({});
|
||||
let [loading, setLoading] = useState(false);let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||
let [loading, setLoading] = useState(false);
|
||||
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'ModelPrice') {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
}
|
||||
newInputs[item.key] = item.value;
|
||||
});
|
||||
setInputs(newInputs);
|
||||
setOriginInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
newInputs[item.key] = item.value;
|
||||
});
|
||||
setInputs(newInputs);
|
||||
setOriginInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
getOptions().then();
|
||||
@@ -65,80 +71,88 @@ const OperationSetting = () => {
|
||||
};
|
||||
|
||||
const handleInputChange = async (e, {name, value}) => {
|
||||
if (name.endsWith('Enabled')) {
|
||||
if (name.endsWith('Enabled') || name === 'DataExportInterval') {
|
||||
await updateOption(name, value);
|
||||
} else {
|
||||
setInputs((inputs) => ({...inputs, [name]: value}));
|
||||
}
|
||||
};
|
||||
|
||||
const submitConfig = async (group) => {
|
||||
switch (group) {
|
||||
case 'monitor':
|
||||
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
|
||||
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
|
||||
const submitConfig = async (group) => {
|
||||
switch (group) {
|
||||
case 'monitor':
|
||||
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
|
||||
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
|
||||
}
|
||||
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
|
||||
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
|
||||
}
|
||||
break;
|
||||
case 'ratio':
|
||||
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
|
||||
if (!verifyJSON(inputs.ModelRatio)) {
|
||||
showError('模型倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelRatio', inputs.ModelRatio);
|
||||
}
|
||||
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
|
||||
if (!verifyJSON(inputs.GroupRatio)) {
|
||||
showError('分组倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('GroupRatio', inputs.GroupRatio);
|
||||
}
|
||||
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
|
||||
if (!verifyJSON(inputs.ModelPrice)) {
|
||||
showError('模型固定价格不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelPrice', inputs.ModelPrice);
|
||||
}
|
||||
break;
|
||||
case 'quota':
|
||||
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
|
||||
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
|
||||
}
|
||||
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
|
||||
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
|
||||
}
|
||||
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
|
||||
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
|
||||
}
|
||||
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
|
||||
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
|
||||
}
|
||||
break;
|
||||
case 'general':
|
||||
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
|
||||
await updateOption('TopUpLink', inputs.TopUpLink);
|
||||
}
|
||||
if (originInputs['ChatLink'] !== inputs.ChatLink) {
|
||||
await updateOption('ChatLink', inputs.ChatLink);
|
||||
}
|
||||
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
|
||||
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
|
||||
}
|
||||
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
|
||||
await updateOption('RetryTimes', inputs.RetryTimes);
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
|
||||
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
|
||||
}
|
||||
break;
|
||||
case 'ratio':
|
||||
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
|
||||
if (!verifyJSON(inputs.ModelRatio)) {
|
||||
showError('模型倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelRatio', inputs.ModelRatio);
|
||||
}
|
||||
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
|
||||
if (!verifyJSON(inputs.GroupRatio)) {
|
||||
showError('分组倍率不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('GroupRatio', inputs.GroupRatio);
|
||||
}
|
||||
break;
|
||||
case 'quota':
|
||||
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
|
||||
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
|
||||
}
|
||||
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
|
||||
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
|
||||
}
|
||||
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
|
||||
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
|
||||
}
|
||||
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
|
||||
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
|
||||
}
|
||||
break;
|
||||
case 'general':
|
||||
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
|
||||
await updateOption('TopUpLink', inputs.TopUpLink);
|
||||
}
|
||||
if (originInputs['ChatLink'] !== inputs.ChatLink) {
|
||||
await updateOption('ChatLink', inputs.ChatLink);
|
||||
}
|
||||
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
|
||||
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
|
||||
}
|
||||
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
|
||||
await updateOption('RetryTimes', inputs.RetryTimes);
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
const deleteHistoryLogs = async () => {
|
||||
console.log(inputs);
|
||||
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
showSuccess(`${data} 条日志已清理!`);
|
||||
return;
|
||||
}
|
||||
showError('日志清理失败:' + message);
|
||||
};return (
|
||||
console.log(inputs);
|
||||
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
showSuccess(`${data} 条日志已清理!`);
|
||||
return;
|
||||
}
|
||||
showError('日志清理失败:' + message);
|
||||
};
|
||||
return (
|
||||
<Grid columns={1}>
|
||||
<Grid.Column>
|
||||
<Form loading={loading}>
|
||||
@@ -200,31 +214,58 @@ const OperationSetting = () => {
|
||||
name='DisplayTokenStatEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DrawingEnabled === 'true'}
|
||||
label='启用绘图功能'
|
||||
name='DrawingEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
submitConfig('general').then();
|
||||
}}>保存通用设置</Form.Button><Divider />
|
||||
<Header as='h3'>
|
||||
日志设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label='启用额度消费日志记录'
|
||||
name='LogConsumeEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
|
||||
name='history_timestamp'
|
||||
onChange={(e, { name, value }) => {
|
||||
setHistoryTimestamp(value);
|
||||
}} />
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
deleteHistoryLogs().then();
|
||||
}}>清理历史日志</Form.Button>
|
||||
}}>保存通用设置</Form.Button><Divider/>
|
||||
<Header as='h3'>
|
||||
日志设置
|
||||
</Header>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.LogConsumeEnabled === 'true'}
|
||||
label='启用额度消费日志记录'
|
||||
name='LogConsumeEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
|
||||
</Form.Group>
|
||||
<Form.Group inline>
|
||||
<Form.Checkbox
|
||||
checked={inputs.DataExportEnabled === 'true'}
|
||||
label='启用数据看板(实验性)'
|
||||
name='DataExportEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Input
|
||||
label='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
|
||||
name='DataExportInterval'
|
||||
type={'number'}
|
||||
step='1'
|
||||
min='1'
|
||||
onChange={handleInputChange}
|
||||
autoComplete='new-password'
|
||||
value={inputs.DataExportInterval}
|
||||
placeholder='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider/>
|
||||
<Form.Group widths={4}>
|
||||
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
|
||||
name='history_timestamp'
|
||||
onChange={(e, {name, value}) => {
|
||||
setHistoryTimestamp(value);
|
||||
}}/>
|
||||
</Form.Group>
|
||||
<Form.Button onClick={() => {
|
||||
deleteHistoryLogs().then();
|
||||
}}>清理历史日志</Form.Button>
|
||||
<Divider/>
|
||||
<Header as='h3'>
|
||||
监控设置
|
||||
@@ -315,6 +356,17 @@ const OperationSetting = () => {
|
||||
<Header as='h3'>
|
||||
倍率设置
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
|
||||
name='ModelPrice'
|
||||
onChange={handleInputChange}
|
||||
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
|
||||
autoComplete='new-password'
|
||||
value={inputs.ModelPrice}
|
||||
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1,一次消耗0.1刀'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='模型倍率'
|
||||
|
||||
@@ -394,7 +394,7 @@ const RedemptionsTable = () => {
|
||||
}
|
||||
let keys = "";
|
||||
for (let i = 0; i < selectedKeys.length; i++) {
|
||||
keys += selectedKeys[i].name + " sk-" + selectedKeys[i].key + "\n";
|
||||
keys += selectedKeys[i].name + " " + selectedKeys[i].key + "\n";
|
||||
}
|
||||
await copyText(keys);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, {useContext, useState} from 'react';
|
||||
import React, {useContext, useMemo, useState} from 'react';
|
||||
import {Link, useNavigate} from 'react-router-dom';
|
||||
import {UserContext} from '../context/User';
|
||||
|
||||
@@ -6,7 +6,7 @@ import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../he
|
||||
import '../index.css';
|
||||
|
||||
import {
|
||||
IconAt,
|
||||
IconCalendarClock,
|
||||
IconHistogram,
|
||||
IconGift,
|
||||
IconKey,
|
||||
@@ -21,78 +21,6 @@ import {
|
||||
import {Nav, Avatar, Dropdown, Layout} from '@douyinfe/semi-ui';
|
||||
|
||||
// HeaderBar Buttons
|
||||
let headerButtons = [
|
||||
{
|
||||
text: '首页',
|
||||
itemKey: 'home',
|
||||
to: '/',
|
||||
icon: <IconHome/>
|
||||
},
|
||||
{
|
||||
text: '渠道',
|
||||
itemKey: 'channel',
|
||||
to: '/channel',
|
||||
icon: <IconLayers/>,
|
||||
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '聊天',
|
||||
itemKey: 'chat',
|
||||
to: '/chat',
|
||||
icon: <IconComment />,
|
||||
className: localStorage.getItem('chat_link')?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '令牌',
|
||||
itemKey: 'token',
|
||||
to: '/token',
|
||||
icon: <IconKey/>
|
||||
},
|
||||
{
|
||||
text: '兑换码',
|
||||
itemKey: 'redemption',
|
||||
to: '/redemption',
|
||||
icon: <IconGift/>,
|
||||
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '钱包',
|
||||
itemKey: 'topup',
|
||||
to: '/topup',
|
||||
icon: <IconCreditCard/>
|
||||
},
|
||||
{
|
||||
text: '用户管理',
|
||||
itemKey: 'user',
|
||||
to: '/user',
|
||||
icon: <IconUser/>,
|
||||
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '日志',
|
||||
itemKey: 'log',
|
||||
to: '/log',
|
||||
icon: <IconHistogram/>
|
||||
},
|
||||
{
|
||||
text: '绘图',
|
||||
itemKey: 'midjourney',
|
||||
to: '/midjourney',
|
||||
icon: <IconImage/>
|
||||
},
|
||||
{
|
||||
text: '设置',
|
||||
itemKey: 'setting',
|
||||
to: '/setting',
|
||||
icon: <IconSetting/>
|
||||
},
|
||||
// {
|
||||
// text: '关于',
|
||||
// itemKey: 'about',
|
||||
// to: '/about',
|
||||
// icon: <IconAt/>
|
||||
// }
|
||||
];
|
||||
|
||||
const SiderBar = () => {
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
@@ -101,6 +29,87 @@ const SiderBar = () => {
|
||||
const [showSidebar, setShowSidebar] = useState(false);
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
const headerButtons = useMemo(() => [
|
||||
{
|
||||
text: '首页',
|
||||
itemKey: 'home',
|
||||
to: '/',
|
||||
icon: <IconHome/>
|
||||
},
|
||||
{
|
||||
text: '渠道',
|
||||
itemKey: 'channel',
|
||||
to: '/channel',
|
||||
icon: <IconLayers/>,
|
||||
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '聊天',
|
||||
itemKey: 'chat',
|
||||
to: '/chat',
|
||||
icon: <IconComment />,
|
||||
className: localStorage.getItem('chat_link')?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '令牌',
|
||||
itemKey: 'token',
|
||||
to: '/token',
|
||||
icon: <IconKey/>
|
||||
},
|
||||
{
|
||||
text: '兑换码',
|
||||
itemKey: 'redemption',
|
||||
to: '/redemption',
|
||||
icon: <IconGift/>,
|
||||
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '钱包',
|
||||
itemKey: 'topup',
|
||||
to: '/topup',
|
||||
icon: <IconCreditCard/>
|
||||
},
|
||||
{
|
||||
text: '用户管理',
|
||||
itemKey: 'user',
|
||||
to: '/user',
|
||||
icon: <IconUser/>,
|
||||
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '日志',
|
||||
itemKey: 'log',
|
||||
to: '/log',
|
||||
icon: <IconHistogram/>
|
||||
},
|
||||
{
|
||||
text: '数据看版',
|
||||
itemKey: 'detail',
|
||||
to: '/detail',
|
||||
icon: <IconCalendarClock />,
|
||||
className: localStorage.getItem('enable_data_export') === 'true'?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '绘图',
|
||||
itemKey: 'midjourney',
|
||||
to: '/midjourney',
|
||||
icon: <IconImage/>,
|
||||
className: localStorage.getItem('enable_drawing') === 'true'?'semi-navigation-item-normal':'tableHiddle',
|
||||
},
|
||||
{
|
||||
text: '设置',
|
||||
itemKey: 'setting',
|
||||
to: '/setting',
|
||||
icon: <IconSetting/>
|
||||
},
|
||||
// {
|
||||
// text: '关于',
|
||||
// itemKey: 'about',
|
||||
// to: '/about',
|
||||
// icon: <IconAt/>
|
||||
// }
|
||||
], [localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), localStorage.getItem('chat_link'), isAdmin()]);
|
||||
|
||||
|
||||
async function logout() {
|
||||
setShowSidebar(false);
|
||||
@@ -133,6 +142,7 @@ const SiderBar = () => {
|
||||
setting: "/setting",
|
||||
about: "/about",
|
||||
chat: "/chat",
|
||||
detail: "/detail",
|
||||
};
|
||||
return (
|
||||
<Link
|
||||
|
||||
@@ -43,10 +43,14 @@ function renderTimestamp(timestamp) {
|
||||
);
|
||||
}
|
||||
|
||||
function renderStatus(status) {
|
||||
function renderStatus(status, model_limits_enabled = false) {
|
||||
switch (status) {
|
||||
case 1:
|
||||
return <Tag color='green' size='large'>已启用</Tag>;
|
||||
if (model_limits_enabled) {
|
||||
return <Tag color='green' size='large'>已启用:限制模型</Tag>;
|
||||
} else {
|
||||
return <Tag color='green' size='large'>已启用</Tag>;
|
||||
}
|
||||
case 2:
|
||||
return <Tag color='red' size='large'> 已禁用 </Tag>;
|
||||
case 3:
|
||||
@@ -78,7 +82,7 @@ const TokensTable = () => {
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
{renderStatus(text)}
|
||||
{renderStatus(text, record.model_limits_enabled)}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
@@ -224,6 +228,11 @@ const TokensTable = () => {
|
||||
|
||||
const closeEdit = () => {
|
||||
setShowEdit(false);
|
||||
setTimeout(() => {
|
||||
setEditingToken({
|
||||
id: undefined,
|
||||
});
|
||||
}, 500);
|
||||
}
|
||||
|
||||
const setTokensFormat = (tokens) => {
|
||||
|
||||
@@ -72,32 +72,49 @@ const UsersTable = () => {
|
||||
}, {
|
||||
title: '状态', dataIndex: 'status', render: (text, record, index) => {
|
||||
return (<div>
|
||||
{renderStatus(text)}
|
||||
{record.DeletedAt !== null? <Tag color='red'>已注销</Tag> : renderStatus(text)}
|
||||
</div>);
|
||||
},
|
||||
}, {
|
||||
title: '', dataIndex: 'operate', render: (text, record, index) => (<div>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'promote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
|
||||
</Popconfirm>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'demote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
|
||||
</Popconfirm>
|
||||
{
|
||||
record.DeletedAt !== null ? <></>:
|
||||
<>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'promote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
|
||||
</Popconfirm>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'demote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
|
||||
</Popconfirm>
|
||||
{record.status === 1 ?
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'disable', record)
|
||||
}}>禁用</Button> :
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'enable', record);
|
||||
}} disabled={record.status === 3}>启用</Button>}
|
||||
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
|
||||
setEditingUser(record);
|
||||
setShowEditUser(true);
|
||||
}}>编辑</Button>
|
||||
</>
|
||||
|
||||
}
|
||||
<Popconfirm
|
||||
title="确定是否要删除此用户?"
|
||||
content="此修改将不可逆"
|
||||
content="硬删除,此修改将不可逆"
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={() => {
|
||||
@@ -108,17 +125,6 @@ const UsersTable = () => {
|
||||
>
|
||||
<Button theme='light' type='danger' style={{marginRight: 1}}>删除</Button>
|
||||
</Popconfirm>
|
||||
{record.status === 1 ?
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'disable', record)
|
||||
}}>禁用</Button> :
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'enable', record);
|
||||
}} disabled={record.status === 3}>启用</Button>}
|
||||
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
|
||||
setEditingUser(record);
|
||||
setShowEditUser(true);
|
||||
}}>编辑</Button>
|
||||
</div>),
|
||||
},];
|
||||
|
||||
|
||||
@@ -1,114 +1,129 @@
|
||||
import { Label } from 'semantic-ui-react';
|
||||
import {Label} from 'semantic-ui-react';
|
||||
import {Tag} from "@douyinfe/semi-ui";
|
||||
|
||||
export function renderText(text, limit) {
|
||||
if (text.length > limit) {
|
||||
return text.slice(0, limit - 3) + '...';
|
||||
}
|
||||
return text;
|
||||
if (text.length > limit) {
|
||||
return text.slice(0, limit - 3) + '...';
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
export function renderGroup(group) {
|
||||
if (group === '') {
|
||||
return <Tag size='large'>default</Tag>;
|
||||
}
|
||||
let groups = group.split(',');
|
||||
groups.sort();
|
||||
return <>
|
||||
{groups.map((group) => {
|
||||
if (group === 'vip' || group === 'pro') {
|
||||
return <Tag size='large' color='yellow'>{group}</Tag>;
|
||||
} else if (group === 'svip' || group === 'premium') {
|
||||
return <Tag size='large' color='red'>{group}</Tag>;
|
||||
}
|
||||
if (group === 'default') {
|
||||
return <Tag size='large'>{group}</Tag>;
|
||||
} else {
|
||||
return <Tag size='large' color={stringToColor(group)}>{group}</Tag>;
|
||||
}
|
||||
})}
|
||||
</>;
|
||||
if (group === '') {
|
||||
return <Tag size='large'>default</Tag>;
|
||||
}
|
||||
let groups = group.split(',');
|
||||
groups.sort();
|
||||
return <>
|
||||
{groups.map((group) => {
|
||||
if (group === 'vip' || group === 'pro') {
|
||||
return <Tag size='large' color='yellow'>{group}</Tag>;
|
||||
} else if (group === 'svip' || group === 'premium') {
|
||||
return <Tag size='large' color='red'>{group}</Tag>;
|
||||
}
|
||||
if (group === 'default') {
|
||||
return <Tag size='large'>{group}</Tag>;
|
||||
} else {
|
||||
return <Tag size='large' color={stringToColor(group)}>{group}</Tag>;
|
||||
}
|
||||
})}
|
||||
</>;
|
||||
}
|
||||
|
||||
export function renderNumber(num) {
|
||||
if (num >= 1000000000) {
|
||||
return (num / 1000000000).toFixed(1) + 'B';
|
||||
} else if (num >= 1000000) {
|
||||
return (num / 1000000).toFixed(1) + 'M';
|
||||
} else if (num >= 10000) {
|
||||
return (num / 1000).toFixed(1) + 'k';
|
||||
} else {
|
||||
if (num >= 1000000000) {
|
||||
return (num / 1000000000).toFixed(1) + 'B';
|
||||
} else if (num >= 1000000) {
|
||||
return (num / 1000000).toFixed(1) + 'M';
|
||||
} else if (num >= 10000) {
|
||||
return (num / 1000).toFixed(1) + 'k';
|
||||
} else {
|
||||
return num;
|
||||
}
|
||||
}
|
||||
|
||||
export function renderQuotaNumberWithDigit(num, digits = 2) {
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
num = num.toFixed(digits);
|
||||
if (displayInCurrency) {
|
||||
return '$' + num;
|
||||
}
|
||||
return num;
|
||||
}
|
||||
}
|
||||
|
||||
export function renderNumberWithPoint(num) {
|
||||
num = num.toFixed(2);
|
||||
if (num >= 100000) {
|
||||
// Convert number to string to manipulate it
|
||||
let numStr = num.toString();
|
||||
// Find the position of the decimal point
|
||||
let decimalPointIndex = numStr.indexOf('.');
|
||||
num = num.toFixed(2);
|
||||
if (num >= 100000) {
|
||||
// Convert number to string to manipulate it
|
||||
let numStr = num.toString();
|
||||
// Find the position of the decimal point
|
||||
let decimalPointIndex = numStr.indexOf('.');
|
||||
|
||||
let wholePart = numStr;
|
||||
let decimalPart = '';
|
||||
let wholePart = numStr;
|
||||
let decimalPart = '';
|
||||
|
||||
// If there is a decimal point, split the number into whole and decimal parts
|
||||
if (decimalPointIndex !== -1) {
|
||||
wholePart = numStr.slice(0, decimalPointIndex);
|
||||
decimalPart = numStr.slice(decimalPointIndex);
|
||||
// If there is a decimal point, split the number into whole and decimal parts
|
||||
if (decimalPointIndex !== -1) {
|
||||
wholePart = numStr.slice(0, decimalPointIndex);
|
||||
decimalPart = numStr.slice(decimalPointIndex);
|
||||
}
|
||||
|
||||
// Take the first two and last two digits of the whole number part
|
||||
let shortenedWholePart = wholePart.slice(0, 2) + '..' + wholePart.slice(-2);
|
||||
|
||||
// Return the formatted number
|
||||
return shortenedWholePart + decimalPart;
|
||||
}
|
||||
|
||||
// Take the first two and last two digits of the whole number part
|
||||
let shortenedWholePart = wholePart.slice(0, 2) + '..' + wholePart.slice(-2);
|
||||
|
||||
// Return the formatted number
|
||||
return shortenedWholePart + decimalPart;
|
||||
}
|
||||
|
||||
// If the number is less than 100,000, return it unmodified
|
||||
return num;
|
||||
// If the number is less than 100,000, return it unmodified
|
||||
return num;
|
||||
}
|
||||
|
||||
export function getQuotaPerUnit() {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
return quotaPerUnit;
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
return quotaPerUnit;
|
||||
}
|
||||
|
||||
export function getQuotaWithUnit(quota, digits = 6) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
return (quota / quotaPerUnit).toFixed(digits);
|
||||
}
|
||||
|
||||
export function renderQuota(quota, digits = 2) {
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
displayInCurrency = displayInCurrency === 'true';
|
||||
if (displayInCurrency) {
|
||||
return '$' + (quota / quotaPerUnit).toFixed(digits);
|
||||
}
|
||||
return renderNumber(quota);
|
||||
let quotaPerUnit = localStorage.getItem('quota_per_unit');
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
quotaPerUnit = parseFloat(quotaPerUnit);
|
||||
displayInCurrency = displayInCurrency === 'true';
|
||||
if (displayInCurrency) {
|
||||
return '$' + (quota / quotaPerUnit).toFixed(digits);
|
||||
}
|
||||
return renderNumber(quota);
|
||||
}
|
||||
|
||||
export function renderQuotaWithPrompt(quota, digits) {
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
displayInCurrency = displayInCurrency === 'true';
|
||||
if (displayInCurrency) {
|
||||
return `(等价金额:${renderQuota(quota, digits)})`;
|
||||
}
|
||||
return '';
|
||||
let displayInCurrency = localStorage.getItem('display_in_currency');
|
||||
displayInCurrency = displayInCurrency === 'true';
|
||||
if (displayInCurrency) {
|
||||
return `(等价金额:${renderQuota(quota, digits)})`;
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
|
||||
'light-blue', 'lime', 'orange', 'pink',
|
||||
'purple', 'red', 'teal', 'violet', 'yellow'
|
||||
'light-blue', 'lime', 'orange', 'pink',
|
||||
'purple', 'red', 'teal', 'violet', 'yellow'
|
||||
]
|
||||
|
||||
export function stringToColor(str) {
|
||||
let sum = 0;
|
||||
// 对字符串中的每个字符进行操作
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
// 将字符的ASCII值加到sum中
|
||||
sum += str.charCodeAt(i);
|
||||
}
|
||||
// 使用模运算得到个位数
|
||||
let i = sum % colors.length;
|
||||
return colors[i];
|
||||
let sum = 0;
|
||||
// 对字符串中的每个字符进行操作
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
// 将字符的ASCII值加到sum中
|
||||
sum += str.charCodeAt(i);
|
||||
}
|
||||
// 使用模运算得到个位数
|
||||
let i = sum % colors.length;
|
||||
return colors[i];
|
||||
}
|
||||
@@ -171,6 +171,32 @@ export function timestamp2string(timestamp) {
|
||||
);
|
||||
}
|
||||
|
||||
export function timestamp2string1(timestamp) {
|
||||
let date = new Date(timestamp * 1000);
|
||||
// let year = date.getFullYear().toString();
|
||||
let month = (date.getMonth() + 1).toString();
|
||||
let day = date.getDate().toString();
|
||||
let hour = date.getHours().toString();
|
||||
if (month.length === 1) {
|
||||
month = '0' + month;
|
||||
}
|
||||
if (day.length === 1) {
|
||||
day = '0' + day;
|
||||
}
|
||||
if (hour.length === 1) {
|
||||
hour = '0' + hour;
|
||||
}
|
||||
return (
|
||||
// year +
|
||||
// '-' +
|
||||
month +
|
||||
'-' +
|
||||
day +
|
||||
' ' +
|
||||
hour + ":00"
|
||||
);
|
||||
}
|
||||
|
||||
export function downloadTextAsFile(text, filename) {
|
||||
let blob = new Blob([text], { type: 'text/plain;charset=utf-8' });
|
||||
let url = URL.createObjectURL(blob);
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { initVChartSemiTheme } from '@visactor/vchart-semi-theme';
|
||||
import VChart from "@visactor/vchart";
|
||||
import React from 'react';
|
||||
import ReactDOM from 'react-dom/client';
|
||||
import {BrowserRouter} from 'react-router-dom';
|
||||
import {Container} from 'semantic-ui-react';
|
||||
import App from './App';
|
||||
import HeaderBar from './components/HeaderBar';
|
||||
import Footer from './components/Footer';
|
||||
@@ -14,6 +15,11 @@ import {StatusProvider} from './context/Status';
|
||||
import {Layout} from "@douyinfe/semi-ui";
|
||||
import SiderBar from "./components/SiderBar";
|
||||
|
||||
// initialization
|
||||
initVChartSemiTheme({
|
||||
isWatchingThemeSwitch: true,
|
||||
});
|
||||
|
||||
const root = ReactDOM.createRoot(document.getElementById('root'));
|
||||
const {Sider, Content, Header} = Layout;
|
||||
root.render(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Header, Segment } from 'semantic-ui-react';
|
||||
import { API, showError } from '../../helpers';
|
||||
import { marked } from 'marked';
|
||||
import {Layout} from "@douyinfe/semi-ui";
|
||||
|
||||
292
web/src/pages/Detail/index.js
Normal file
292
web/src/pages/Detail/index.js
Normal file
@@ -0,0 +1,292 @@
|
||||
import React, {useEffect, useRef, useState} from 'react';
|
||||
import {Button, Col, Form, Layout, Row, Spin} from "@douyinfe/semi-ui";
|
||||
import VChart from '@visactor/vchart';
|
||||
import {useEffectOnce} from "usehooks-ts";
|
||||
import {API, isAdmin, showError, timestamp2string, timestamp2string1} from "../../helpers";
|
||||
import {getQuotaWithUnit, renderNumber, renderQuotaNumberWithDigit} from "../../helpers/render";
|
||||
|
||||
const Detail = (props) => {
|
||||
|
||||
let now = new Date();
|
||||
const [inputs, setInputs] = useState({
|
||||
username: '',
|
||||
token_name: '',
|
||||
model_name: '',
|
||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
|
||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||
channel: ''
|
||||
});
|
||||
const {username, token_name, model_name, start_timestamp, end_timestamp, channel} = inputs;
|
||||
const isAdminUser = isAdmin();
|
||||
const initialized = useRef(false)
|
||||
const [modelDataChart, setModelDataChart] = useState(null);
|
||||
const [modelDataPieChart, setModelDataPieChart] = useState(null);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [quotaData, setQuotaData] = useState([]);
|
||||
const [quotaDataPie, setQuotaDataPie] = useState([]);
|
||||
const [quotaDataLine, setQuotaDataLine] = useState([]);
|
||||
|
||||
const handleInputChange = (value, name) => {
|
||||
setInputs((inputs) => ({...inputs, [name]: value}));
|
||||
};
|
||||
|
||||
const spec_line = {
|
||||
type: 'bar',
|
||||
data: [
|
||||
{
|
||||
id: 'barData',
|
||||
values: [
|
||||
]
|
||||
}
|
||||
],
|
||||
xField: 'Time',
|
||||
yField: 'Usage',
|
||||
seriesField: 'Model',
|
||||
stack: true,
|
||||
legends: {
|
||||
visible: true
|
||||
},
|
||||
title: {
|
||||
visible: true,
|
||||
text: '模型消耗分布(小时)'
|
||||
},
|
||||
bar: {
|
||||
// The state style of bar
|
||||
state: {
|
||||
hover: {
|
||||
stroke: '#000',
|
||||
lineWidth: 1
|
||||
}
|
||||
}
|
||||
},
|
||||
tooltip: {
|
||||
mark: {
|
||||
content: [
|
||||
{
|
||||
key: datum => datum['Model'],
|
||||
value: datum => renderQuotaNumberWithDigit(datum['Usage'], 4)
|
||||
}
|
||||
]
|
||||
},
|
||||
dimension: {
|
||||
content: [
|
||||
{
|
||||
key: datum => datum['Model'],
|
||||
value: datum => datum['Usage']
|
||||
}
|
||||
],
|
||||
updateContent: array => {
|
||||
// sort by value
|
||||
array.sort((a, b) => b.value - a.value);
|
||||
// add $
|
||||
for (let i = 0; i < array.length; i++) {
|
||||
array[i].value = renderQuotaNumberWithDigit(array[i].value, 4);
|
||||
}
|
||||
return array;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const spec_pie = {
|
||||
type: 'pie',
|
||||
data: [
|
||||
{
|
||||
id: 'id0',
|
||||
values: [
|
||||
{ type: 'null', value: '0' },
|
||||
]
|
||||
}
|
||||
],
|
||||
outerRadius: 0.8,
|
||||
innerRadius: 0.5,
|
||||
padAngle: 0.6,
|
||||
valueField: 'value',
|
||||
categoryField: 'type',
|
||||
pie: {
|
||||
style: {
|
||||
cornerRadius: 10
|
||||
},
|
||||
state: {
|
||||
hover: {
|
||||
outerRadius: 0.85,
|
||||
stroke: '#000',
|
||||
lineWidth: 1
|
||||
},
|
||||
selected: {
|
||||
outerRadius: 0.85,
|
||||
stroke: '#000',
|
||||
lineWidth: 1
|
||||
}
|
||||
}
|
||||
},
|
||||
title: {
|
||||
visible: true,
|
||||
text: '模型调用次数占比'
|
||||
},
|
||||
legends: {
|
||||
visible: true,
|
||||
orient: 'left'
|
||||
},
|
||||
label: {
|
||||
visible: true
|
||||
},
|
||||
tooltip: {
|
||||
mark: {
|
||||
content: [
|
||||
{
|
||||
key: datum => datum['type'],
|
||||
value: datum => renderNumber(datum['value'])
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const loadQuotaData = async (lineChart, pieChart) => {
|
||||
setLoading(true);
|
||||
|
||||
let url = '';
|
||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||
if (isAdminUser) {
|
||||
url = `/api/data/?username=${username}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
} else {
|
||||
url = `/api/data/self/?start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||
}
|
||||
const res = await API.get(url);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
setQuotaData(data);
|
||||
if (data.length === 0) {
|
||||
data.push({
|
||||
'count': 0,
|
||||
'model_name': '无数据',
|
||||
'quota': 0,
|
||||
'created_at': now.getTime() / 1000
|
||||
})
|
||||
}
|
||||
updateChart(lineChart, pieChart, data);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
};
|
||||
|
||||
const refresh = async () => {
|
||||
await loadQuotaData(modelDataChart, modelDataPieChart);
|
||||
};
|
||||
|
||||
const initChart = async () => {
|
||||
let lineChart = modelDataChart
|
||||
if (!modelDataChart) {
|
||||
lineChart = new VChart(spec_line, {dom: 'model_data'});
|
||||
setModelDataChart(lineChart);
|
||||
lineChart.renderAsync();
|
||||
}
|
||||
let pieChart = modelDataPieChart
|
||||
if (!modelDataPieChart) {
|
||||
pieChart = new VChart(spec_pie, {dom: 'model_pie'});
|
||||
setModelDataPieChart(pieChart);
|
||||
pieChart.renderAsync();
|
||||
}
|
||||
console.log('init vchart');
|
||||
await loadQuotaData(lineChart, pieChart)
|
||||
}
|
||||
|
||||
const updateChart = (lineChart, pieChart, data) => {
|
||||
if (isAdminUser) {
|
||||
// 将所有用户合并
|
||||
}
|
||||
let pieData = [];
|
||||
let lineData = [];
|
||||
for (let i = 0; i < data.length; i++) {
|
||||
const item = data[i];
|
||||
// 合并model_name
|
||||
let pieItem = pieData.find(it => it.type === item.model_name);
|
||||
if (pieItem) {
|
||||
pieItem.value += item.count;
|
||||
} else {
|
||||
pieData.push({
|
||||
"type": item.model_name,
|
||||
"value": item.count
|
||||
});
|
||||
}
|
||||
// 合并created_at和model_name 为 lineData, created_at 数据类型是小时的时间戳
|
||||
// 转换日期格式
|
||||
let createTime = timestamp2string1(item.created_at);
|
||||
let lineItem = lineData.find(it => it.Time === createTime && it.Model === item.model_name);
|
||||
if (lineItem) {
|
||||
lineItem.Usage += parseFloat(getQuotaWithUnit(item.quota));
|
||||
} else {
|
||||
lineData.push({
|
||||
"Time": createTime,
|
||||
"Model": item.model_name,
|
||||
"Usage": parseFloat(getQuotaWithUnit(item.quota))
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
// sort by count
|
||||
pieData.sort((a, b) => b.value - a.value);
|
||||
pieChart.updateData('id0', pieData);
|
||||
lineChart.updateData('barData', lineData);
|
||||
pieChart.reLayout();
|
||||
lineChart.reLayout();
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
if (!initialized.current) {
|
||||
initialized.current = true;
|
||||
initChart();
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Layout>
|
||||
<Layout.Header>
|
||||
<h3>数据看板</h3>
|
||||
</Layout.Header>
|
||||
<Layout.Content>
|
||||
<Form layout='horizontal' style={{marginTop: 10}}>
|
||||
<>
|
||||
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
|
||||
initValue={start_timestamp}
|
||||
value={start_timestamp} type='dateTime'
|
||||
name='start_timestamp'
|
||||
onChange={value => handleInputChange(value, 'start_timestamp')}/>
|
||||
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
|
||||
initValue={end_timestamp}
|
||||
value={end_timestamp} type='dateTime'
|
||||
name='end_timestamp'
|
||||
onChange={value => handleInputChange(value, 'end_timestamp')}/>
|
||||
{
|
||||
isAdminUser && <>
|
||||
<Form.Input field="username" label='用户名称' style={{width: 176}} value={username}
|
||||
placeholder={'可选值'} name='username'
|
||||
onChange={value => handleInputChange(value, 'username')}/>
|
||||
</>
|
||||
}
|
||||
<Form.Section>
|
||||
<Button label='查询' type="primary" htmlType="submit" className="btn-margin-right"
|
||||
onClick={refresh} loading={loading}>查询</Button>
|
||||
</Form.Section>
|
||||
</>
|
||||
</Form>
|
||||
<Spin spinning={loading}>
|
||||
<div style={{height: 500}}>
|
||||
<div id="model_pie" style={{width: '100%', minWidth: 100}}></div>
|
||||
</div>
|
||||
<div style={{height: 500}}>
|
||||
<div id="model_data" style={{width: '100%', minWidth: 100}}></div>
|
||||
</div>
|
||||
</Spin>
|
||||
</Layout.Content>
|
||||
</Layout>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
export default Detail;
|
||||
@@ -2,22 +2,37 @@ import React, {useEffect, useRef, useState} from 'react';
|
||||
import {useParams, useNavigate} from 'react-router-dom';
|
||||
import {API, isMobile, showError, showSuccess, timestamp2string} from '../../helpers';
|
||||
import {renderQuota, renderQuotaWithPrompt} from '../../helpers/render';
|
||||
import {Layout, SideSheet, Button, Space, Spin, Banner, Input, DatePicker, AutoComplete, Typography} from "@douyinfe/semi-ui";
|
||||
import {
|
||||
Layout,
|
||||
SideSheet,
|
||||
Button,
|
||||
Space,
|
||||
Spin,
|
||||
Banner,
|
||||
Input,
|
||||
DatePicker,
|
||||
AutoComplete,
|
||||
Typography,
|
||||
Checkbox, Select
|
||||
} from "@douyinfe/semi-ui";
|
||||
import Title from "@douyinfe/semi-ui/lib/es/typography/title";
|
||||
import {Divider} from "semantic-ui-react";
|
||||
|
||||
const EditToken = (props) => {
|
||||
const isEdit = props.editingToken.id !== undefined;
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [loading, setLoading] = useState(isEdit);
|
||||
const originInputs = {
|
||||
name: '',
|
||||
remain_quota: isEdit ? 0 : 500000,
|
||||
expired_time: -1,
|
||||
unlimited_quota: false
|
||||
unlimited_quota: false,
|
||||
model_limits_enabled: false,
|
||||
model_limits: [],
|
||||
};
|
||||
const [inputs, setInputs] = useState(originInputs);
|
||||
const {name, remain_quota, expired_time, unlimited_quota} = inputs;
|
||||
const {name, remain_quota, expired_time, unlimited_quota, model_limits_enabled, model_limits} = inputs;
|
||||
// const [visible, setVisible] = useState(false);
|
||||
const [models, setModels] = useState({});
|
||||
const navigate = useNavigate();
|
||||
const handleInputChange = (name, value) => {
|
||||
setInputs((inputs) => ({...inputs, [name]: value}));
|
||||
@@ -44,6 +59,20 @@ const EditToken = (props) => {
|
||||
setInputs({...inputs, unlimited_quota: !unlimited_quota});
|
||||
};
|
||||
|
||||
const loadModels = async () => {
|
||||
let res = await API.get(`/api/user/models`);
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
let localModelOptions = data.map((model) => ({
|
||||
label: model,
|
||||
value: model
|
||||
}));
|
||||
setModels(localModelOptions);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
}
|
||||
|
||||
const loadToken = async () => {
|
||||
setLoading(true);
|
||||
let res = await API.get(`/api/token/${props.editingToken.id}`);
|
||||
@@ -52,6 +81,11 @@ const EditToken = (props) => {
|
||||
if (data.expired_time !== -1) {
|
||||
data.expired_time = timestamp2string(data.expired_time);
|
||||
}
|
||||
if (data.model_limits !== '') {
|
||||
data.model_limits = data.model_limits.split(',');
|
||||
} else {
|
||||
data.model_limits = [];
|
||||
}
|
||||
setInputs(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -59,17 +93,22 @@ const EditToken = (props) => {
|
||||
setLoading(false);
|
||||
};
|
||||
useEffect(() => {
|
||||
if (isEdit) {
|
||||
loadToken().then(
|
||||
() => {
|
||||
// console.log(inputs);
|
||||
}
|
||||
);
|
||||
} else {
|
||||
setInputs(originInputs);
|
||||
}
|
||||
setIsEdit(props.editingToken.id !== undefined);
|
||||
}, [props.editingToken.id]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isEdit) {
|
||||
setInputs(originInputs);
|
||||
} else {
|
||||
loadToken().then(
|
||||
() => {
|
||||
// console.log(inputs);
|
||||
}
|
||||
);
|
||||
}
|
||||
loadModels();
|
||||
}, [isEdit]);
|
||||
|
||||
// 新增 state 变量 tokenCount 来记录用户想要创建的令牌数量,默认为 1
|
||||
const [tokenCount, setTokenCount] = useState(1);
|
||||
|
||||
@@ -107,7 +146,7 @@ const EditToken = (props) => {
|
||||
}
|
||||
localInputs.expired_time = Math.ceil(time / 1000);
|
||||
}
|
||||
|
||||
localInputs.model_limits = localInputs.model_limits.join(',');
|
||||
let res = await API.put(`/api/token/`, {...localInputs, id: parseInt(props.editingToken.id)});
|
||||
const {success, message} = res.data;
|
||||
if (success) {
|
||||
@@ -137,7 +176,7 @@ const EditToken = (props) => {
|
||||
}
|
||||
localInputs.expired_time = Math.ceil(time / 1000);
|
||||
}
|
||||
|
||||
localInputs.model_limits = localInputs.model_limits.join(',');
|
||||
let res = await API.post(`/api/token/`, localInputs);
|
||||
const {success, message} = res.data;
|
||||
|
||||
@@ -234,7 +273,7 @@ const EditToken = (props) => {
|
||||
value={remain_quota}
|
||||
autoComplete='new-password'
|
||||
type='number'
|
||||
position={'top'}
|
||||
// position={'top'}
|
||||
data={[
|
||||
{value: 500000, label: '1$'},
|
||||
{value: 5000000, label: '10$'},
|
||||
@@ -245,27 +284,30 @@ const EditToken = (props) => {
|
||||
]}
|
||||
disabled={unlimited_quota}
|
||||
/>
|
||||
<div style={{marginTop: 20}}>
|
||||
<Typography.Text>新建数量</Typography.Text>
|
||||
</div>
|
||||
|
||||
{!isEdit && (
|
||||
<AutoComplete
|
||||
style={{ marginTop: 8 }}
|
||||
label='数量'
|
||||
placeholder={'请选择或输入创建令牌的数量'}
|
||||
onChange={(value) => handleTokenCountChange(value)}
|
||||
onSelect={(value) => handleTokenCountChange(value)}
|
||||
value={tokenCount.toString()}
|
||||
autoComplete='off'
|
||||
type='number'
|
||||
data={[
|
||||
{ value: 10, label: '10个' },
|
||||
{ value: 20, label: '20个' },
|
||||
{ value: 30, label: '30个' },
|
||||
{ value: 100, label: '100个' },
|
||||
]}
|
||||
disabled={unlimited_quota}
|
||||
/>
|
||||
<>
|
||||
<div style={{marginTop: 20}}>
|
||||
<Typography.Text>新建数量</Typography.Text>
|
||||
</div>
|
||||
<AutoComplete
|
||||
style={{ marginTop: 8 }}
|
||||
label='数量'
|
||||
placeholder={'请选择或输入创建令牌的数量'}
|
||||
onChange={(value) => handleTokenCountChange(value)}
|
||||
onSelect={(value) => handleTokenCountChange(value)}
|
||||
value={tokenCount.toString()}
|
||||
autoComplete='off'
|
||||
type='number'
|
||||
data={[
|
||||
{ value: 10, label: '10个' },
|
||||
{ value: 20, label: '20个' },
|
||||
{ value: 30, label: '30个' },
|
||||
{ value: 100, label: '100个' },
|
||||
]}
|
||||
disabled={unlimited_quota}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
<div>
|
||||
@@ -273,6 +315,34 @@ const EditToken = (props) => {
|
||||
setUnlimitedQuota();
|
||||
}}>{unlimited_quota ? '取消无限额度' : '设为无限额度'}</Button>
|
||||
</div>
|
||||
<Divider/>
|
||||
<div style={{marginTop: 10, display: 'flex'}}>
|
||||
<Space>
|
||||
<Checkbox
|
||||
name='model_limits_enabled'
|
||||
checked={model_limits_enabled}
|
||||
onChange={(e) => handleInputChange('model_limits_enabled', e.target.checked)}
|
||||
>
|
||||
</Checkbox>
|
||||
<Typography.Text>启用模型限制(非必要,不建议启用)</Typography.Text>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
<Select
|
||||
style={{marginTop: 8}}
|
||||
placeholder={'请选择该渠道所支持的模型'}
|
||||
name='models'
|
||||
required
|
||||
multiple
|
||||
selection
|
||||
onChange={value => {
|
||||
handleInputChange('model_limits', value);
|
||||
}}
|
||||
value={inputs.model_limits}
|
||||
autoComplete='new-password'
|
||||
optionList={models}
|
||||
disabled={!model_limits_enabled}
|
||||
/>
|
||||
</Spin>
|
||||
</SideSheet>
|
||||
</>
|
||||
|
||||
@@ -131,7 +131,7 @@ const TopUp = () => {
|
||||
}, []);
|
||||
|
||||
const renderAmount = () => {
|
||||
console.log(amount);
|
||||
// console.log(amount);
|
||||
return amount + '元';
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user