Compare commits

...

75 Commits

Author SHA1 Message Date
Calcium-Ion
954ed893b9 Update LICENSE 2024-01-08 19:48:30 +08:00
CaIon
95e84f2bb1 perf: 优化数据看板性能和显示效果 2024-01-08 18:49:10 +08:00
CaIon
f3f8cdc4a3 update README.md 2024-01-08 16:26:00 +08:00
CaIon
1244963e81 feat: 可设置令牌能调用的模型 2024-01-08 16:25:17 +08:00
CaIon
8f36a995ef perf: 美化数据看板 2024-01-08 11:32:27 +08:00
CaIon
daf0be4915 fix: fix JSON tag in Log struct 2024-01-08 11:25:22 +08:00
CaIon
33b93e1e9a 删除前端无用log 2024-01-08 11:24:47 +08:00
CaIon
f3124e7252 fix: 减少username和model_name字段长度为64 2024-01-08 10:11:36 +08:00
CaIon
ebedf2cb7e fix: 减少group和model字段长度为64 2024-01-08 10:08:20 +08:00
Calcium-Ion
d68a50125b Update README.md 2024-01-08 00:25:12 +08:00
CaIon
18e9d37057 fix: 修复侧边导航栏需要刷新才出现选项的问题 2024-01-07 23:54:51 +08:00
CaIon
e2d994d73a fix: 修复mj固定价格设置无效的问题 2024-01-07 23:40:06 +08:00
CaIon
29dbdf01f0 fix: 数据面板次数统计错误 2024-01-07 23:28:27 +08:00
CaIon
64862cd634 fix: recover panic 2024-01-07 22:25:03 +08:00
CaIon
a5e0f86c65 fix: 索引长度 2024-01-07 22:16:20 +08:00
CaIon
2a995a5da2 fix: 索引名称长度 2024-01-07 22:07:17 +08:00
CaIon
bba6174745 fix: fix gemini panic 2024-01-07 21:27:28 +08:00
CaIon
b91a269ddb fix: 修复数据看板饼图显示错误 2024-01-07 20:48:09 +08:00
CaIon
1c2bba8979 feat: 完善数据看板选择时间区间 2024-01-07 19:47:35 +08:00
CaIon
ce05e7dd86 fix: 优化数据看板更新逻辑 2024-01-07 18:33:41 +08:00
CaIon
bf8794d257 feat: 新增数据看板 2024-01-07 18:31:14 +08:00
CaIon
c09df83f34 feat: 允许关闭绘图选项 2024-01-07 18:31:04 +08:00
CaIon
79b1201f47 fix: 修复兑换码复制带sk- 2024-01-06 21:49:11 +08:00
CaIon
b3fc335908 修复mj计费问题 2024-01-05 12:23:26 +08:00
Calcium-Ion
290fed3841 Merge pull request #41 from Calcium-Ion/optimize/mj
optimize: MJ 部分调整、优化
2024-01-03 22:45:23 +08:00
Calcium-Ion
aa6d623981 Merge pull request #42 from Tailen/main
feat: 添加 davinci-002 和 babbage-002 模型
2024-01-03 22:43:28 +08:00
Hao Mei
afbd585048 feat: add support for davinci-002 and babbage-002 2024-01-03 20:19:00 +08:00
Xyfacai
5c747dfee2 optimize: MJ 部分调整、优化
MJ
增加simple-change、list接口,
变换和重试操作区别出来,价格与绘图一样
优化图片返回
2024-01-01 22:46:05 +08:00
CaIon
89dd0e05a0 update README.md 2023-12-29 16:08:50 +08:00
CaIon
7e4fe14871 更新用户时刷新缓存数据 2023-12-29 00:00:23 +08:00
CaIon
85411bc167 update README.md 2023-12-28 23:59:57 +08:00
CaIon
b86bc169a5 无效的请求返回具体原因 2023-12-28 23:59:46 +08:00
CaIon
bdd611fd33 feat: 加入渠道加权随机功能 2023-12-27 19:00:47 +08:00
CaIon
1a8a24698f Happy New Year 2023-12-27 18:04:02 +08:00
CaIon
c09f3b9282 enable parseTime=true 2023-12-27 16:47:32 +08:00
CaIon
e8bcf60f0a 修复不能删除注销用户的问题 2023-12-27 16:46:18 +08:00
CaIon
14592f9758 support gemini-pro-vision 2023-12-27 16:32:54 +08:00
CaIon
4036355fae bump golang.org/x/crypto from 0.14.0 to 0.17.0 2023-12-27 16:32:42 +08:00
CaIon
fa2efb7357 bump golang.org/x/crypto from 0.14.0 to 0.17.0 2023-12-27 15:39:29 +08:00
CaIon
7a8344c40a 修复用户管理显示问题 2023-12-27 15:35:16 +08:00
CaIon
6ad2544415 enable mysql parseTime 2023-12-27 15:33:35 +08:00
CaIon
7cd1261a81 Revert "fix delete user"
This reverts commit 07fe5a02
2023-12-27 15:31:28 +08:00
CaIon
45ed973f1c fix delete user 2023-12-27 15:15:07 +08:00
CaIon
07fe5a02af fix delete user 2023-12-27 14:36:19 +08:00
CaIon
7a4969c238 update Dockerfile 2023-12-27 00:19:25 +08:00
CaIon
ad6842da7f 用户注销后禁止再次注册 2023-12-27 00:19:11 +08:00
CaIon
3c52a0991f 恢复渠道优先级可设置为负数 2023-12-26 23:49:57 +08:00
Xyfacai
fd4ef086dc fix: 优化 mj 获取进度 2023-12-23 23:14:58 +08:00
CaIon
7c4719b6ee update README.md 2023-12-21 23:12:31 +08:00
CaIon
b9d040cf52 fix VERSION 2023-12-21 23:11:52 +08:00
CaIon
6e8ff8c057 fix gemini 2023-12-21 23:08:09 +08:00
CaIon
676dc95793 update README.md 2023-12-21 20:21:15 +08:00
CaIon
1c06bddafe 优化gpt-4-gizmo-*日志 2023-12-21 20:20:09 +08:00
CaIon
3475643257 支持设置模型按次计费 2023-12-21 20:14:04 +08:00
CaIon
45e1042e58 update GeneralOpenAIRequest 2023-12-20 21:20:09 +08:00
CaIon
f5a36a05e5 log输出更多有效信息 2023-12-20 20:53:36 +08:00
CaIon
5730c69385 个人设置添加修改密码功能 2023-12-20 20:53:15 +08:00
CaIon
2d33283afb fix gemini 2023-12-19 12:53:56 +08:00
CaIon
e057c0e42e 添加gemini支持 2023-12-18 23:46:05 +08:00
CaIon
b3f1da44dd 修复vision格式问题 2023-12-18 23:46:05 +08:00
CaIon
f048cefed1 update README.md 2023-12-18 23:46:05 +08:00
CaIon
0226d94ea6 修改docker-compose.yml 2023-12-18 23:46:04 +08:00
CaIon
42469cb782 修复无法指定渠道id的问题 2023-12-14 16:43:20 +08:00
CaIon
e1da1e31d5 添加批量删除渠道功能 2023-12-14 16:35:03 +08:00
CaIon
0fdd4fc6e3 美化绘画IU 2023-12-14 15:16:51 +08:00
CaIon
261dc43c4e 修复 测试所有渠道按钮 失效的问题 2023-12-14 15:16:27 +08:00
CaIon
6463e0539f Warning when SESSION_SECRET is 'random_string' 2023-12-14 11:10:14 +08:00
CaIon
c5f08a757d Merge remote-tracking branch 'public/main' into latest 2023-12-14 11:02:08 +08:00
CaIon
8a9bd08d66 update README.md 2023-12-14 11:01:53 +08:00
Calcium-Ion
751c33a6c0 Merge pull request #30 from YOMIkio/email
fix: add Date header for email
2023-12-14 10:53:39 +08:00
CaIon
57f664d0fa 修复Safari无法拉起支付的问题 2023-12-13 17:02:33 +08:00
liyujie
cdcbebce6d fix: add Date header for email 2023-12-13 16:51:19 +08:00
CaIon
a29e3765f4 修复gpts无法设置倍率的问题 2023-12-13 00:29:53 +08:00
CaIon
766e20719d update README.md 2023-12-12 16:41:51 +08:00
CaIon
4b93f185bb 添加gpt-4-1106-vision-preview模型 2023-12-12 16:41:37 +08:00
64 changed files with 2798 additions and 622 deletions

View File

@@ -1,10 +1,6 @@
MIT License
New API
Copyright (c) 2023 CalciumIon
Based on One API
Copyright (c) 2023 JustSong
Copyright (c) 2024 Calcium-Ion
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -1,5 +1,5 @@
# Neko API
# New API
> [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
@@ -29,12 +29,31 @@
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况方便二次分销
5. 渠道显示已使用额度,支持指定组织访问
6. 分页支持选择每页显示数量
7. 支持gpt-4-1106-vision-previewdall-e-3tts-1
7. 支持 gpt-4-1106-vision-previewdall-e-3tts-1
8. 支持第三方模型 **gps** gpt-4-gizmo-*在渠道中添加自定义模型gpt-4-gizmo-*即可
9. 兼容原版One API的数据库可直接使用原版数据库one-api.db
10. 支持模型按次数收费,可在 系统设置-运营设置 中设置
11. 支持gemini-progemini-pro-vision模型
12. 支持渠道**加权随机**
13. 数据看板
14. 可设置令牌能调用的模型
## 部署
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="500">
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
## 界面截图
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/ad0e7aae-0203-471c-9716-2d83768927d4)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/d1ac216e-0804-4105-9fdc-66b35022d861)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)
@@ -42,8 +61,11 @@
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/90d7d763-6a77-4b36-9f76-2bb30f18583d)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/e414228a-3c35-429a-b298-6451d76d9032)
夜间模式
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/1c66b593-bb9e-4757-9720-ff2759539242)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/5b3228e8-2556-44f7-97d6-4f8d8ee6effa)
![image](https://github.com/Calcium-Ion/new-api/assets/61247483/af9a07ee-5101-4b3d-8bd9-ae21a4fd7e9e)
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=Calcium-Ion/new-api&type=Date)](https://star-history.com/#Calcium-Ion/new-api&Date)

View File

@@ -11,7 +11,7 @@ import (
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "One API"
var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var PayAddress = ""
var EpayId = ""
@@ -24,6 +24,9 @@ var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var DrawingEnabled = true
var DataExportEnabled = true
var DataExportInterval = 5 // unit: minute
// Any options with "Secret", "Token" in its key won't be return by GetOptions
@@ -190,6 +193,7 @@ const (
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
)
var ChannelBaseURLs = []string{

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
@@ -16,8 +17,9 @@ func SendEmail(subject string, receiver string, content string) error {
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, content))
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")

View File

@@ -12,7 +12,7 @@ import (
"strings"
)
func DecodeBase64ImageData(base64String string) (image.Config, error) {
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
// 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:]
@@ -22,20 +22,51 @@ func DecodeBase64ImageData(base64String string) (image.Config, error) {
decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, err
return image.Config{}, "", err
}
// 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData)
config, err := getImageConfig(reader)
return config, err
config, format, err := getImageConfig(reader)
return config, format, err
}
func DecodeUrlImageData(imageUrl string) (image.Config, error) {
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := http.Get(imageUrl)
if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, err
return image.Config{}, "", err
}
// 限制读取的字节数,防止下载整个图片
@@ -45,14 +76,14 @@ func DecodeUrlImageData(imageUrl string) (image.Config, error) {
// log.Fatal(err)
//}
//log.Printf("%x", data)
config, err := getImageConfig(limitReader)
config, format, err := getImageConfig(limitReader)
response.Body.Close()
return config, err
return config, format, err
}
func getImageConfig(reader io.Reader) (image.Config, error) {
func getImageConfig(reader io.Reader) (image.Config, string, error) {
// 读取图片的头部信息来获取图片尺寸
config, _, err := image.DecodeConfig(reader)
config, format, err := image.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
SysLog(err.Error())
@@ -61,9 +92,10 @@ func getImageConfig(reader io.Reader) (image.Config, error) {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
SysLog(err.Error())
}
format = "webp"
}
if err != nil {
return image.Config{}, err
return image.Config{}, "", err
}
return config, nil
return config, format, nil
}

View File

@@ -16,7 +16,7 @@ var (
)
func printHelp() {
fmt.Println("One API " + Version + " - All in one API service for OpenAI API.")
fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
@@ -36,7 +36,14 @@ func init() {
}
if os.Getenv("SESSION_SECRET") != "" {
SessionSecret = os.Getenv("SESSION_SECRET")
ss := os.Getenv("SESSION_SECRET")
if ss == "random_string" {
log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
log.Println("警告SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
log.Fatal("Please set SESSION_SECRET to a random string.")
} else {
SessionSecret = ss
}
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")

View File

@@ -14,7 +14,8 @@ import (
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
"midjourney": 50,
//"midjourney": 50,
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
@@ -23,6 +24,7 @@ var ModelRatio = map[string]float64{
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
@@ -30,6 +32,8 @@ var ModelRatio = map[string]float64{
"gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"text-ada-001": 0.2,
"text-babbage-001": 0.25,
"text-curie-001": 1,
@@ -59,6 +63,8 @@ var ModelRatio = map[string]float64{
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@@ -74,6 +80,41 @@ var ModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
var ModelPrice = map[string]float64{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_describe": 0.05,
"mj_upscale": 0.05,
}
func ModelPrice2JSONString() string {
jsonBytes, err := json.Marshal(ModelPrice)
if err != nil {
SysError("error marshalling model price: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelPriceByJSONString(jsonStr string) error {
ModelPrice = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelPrice)
}
func GetModelPrice(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
price, ok := ModelPrice[name]
if !ok {
SysError("model price not found: " + name)
return -1
}
return price
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
@@ -88,6 +129,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}
func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)

View File

@@ -168,6 +168,11 @@ func GetRandomString(length int) string {
return string(key)
}
func GetRandomInt(max int) int {
//rand.Seed(time.Now().UnixNano())
return rand.Intn(max)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}

View File

@@ -29,6 +29,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:

View File

@@ -151,6 +151,36 @@ func DeleteDisabledChannel(c *gin.Context) {
return
}
type ChannelBatch struct {
Ids []int `json:"ids"`
}
func DeleteChannelBatch(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchDeleteChannels(channelBatch.Ids)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}
func UpdateChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)

View File

@@ -16,8 +16,10 @@ import (
"time"
)
func UpdateMidjourneyTask() {
/*func UpdateMidjourneyTask() {
//revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
@@ -28,27 +30,28 @@ func UpdateMidjourneyTask() {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
log.Printf("检测到未完成的任务数有: %v", len(tasks))
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks {
log.Printf("未完成的任务信息: %v", task)
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId)
task.Status = "FAILURE"
task.Progress = "100%"
err := task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
log.Printf("requestUrl: %s", requestUrl)
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
@@ -111,7 +114,7 @@ func UpdateMidjourneyTask() {
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
log.Println(task.MjId + " 构建失败," + task.FailReason)
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
@@ -126,8 +129,8 @@ func UpdateMidjourneyTask() {
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, 1, logContent)
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
@@ -142,6 +145,180 @@ func UpdateMidjourneyTask() {
}
}
}
*/
func UpdateMidjourneyTaskBulk() {
//revocer
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) == 0 {
continue
}
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
for _, task := range tasks {
if task.MjId == "" {
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(taskChannelM) == 0 {
continue
}
for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
continue
}
resp.Body.Close()
req.Body.Close()
cancel()
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
}
}
}
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
if oldTask.Code != 1 {
return true
}
if oldTask.Progress != newTask.Progress {
return true
}
if oldTask.PromptEn != newTask.PromptEn {
return true
}
if oldTask.State != newTask.State {
return true
}
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.ImageUrl != newTask.ImageUrl {
return true
}
if oldTask.Status != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true
}
return false
}
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
@@ -187,6 +364,12 @@ func GetUserMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if !strings.Contains(common.ServerAddress, "localhost") {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",

View File

@@ -34,6 +34,8 @@ func GetStatus(c *gin.Context) {
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_data_export": common.DataExportEnabled,
},
})
return

View File

@@ -261,6 +261,15 @@ func init() {
Root: "gpt-4-vision-preview",
Parent: nil,
},
{
Id: "gpt-4-1106-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-vision-preview",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
@@ -351,6 +360,24 @@ func init() {
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "babbage-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "babbage-002",
Parent: nil,
},
{
Id: "davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "davinci-002",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
@@ -414,6 +441,24 @@ func init() {
Root: "PaLM-2",
Parent: nil,
},
{
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",

View File

@@ -18,7 +18,7 @@ type ClaudeMetadata struct {
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
MaxTokensToSample uint `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`

336
controller/relay-gemini.go Normal file
View File

@@ -0,0 +1,336 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
const (
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
//SafetySettings: []GeminiChatSafetySettings{
// {
// Category: "HARM_CATEGORY_HARASSMENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_HATE_SPEECH",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
//},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := GeminiChatContent{
Role: message.Role,
Parts: []GeminiPart{
{
Text: string(message.Content),
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
choice := OpenAITextResponseChoice{
Index: i,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: stopFinishReason,
}
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = content
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -55,9 +55,18 @@ type MidjourneyWithoutStatus struct {
ChannelId int `json:"channel_id"`
}
var DefaultModelPrice = map[string]float64{
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_describe": 0.05,
"mj_upscale": 0.05,
}
func RelayMidjourneyImage(c *gin.Context) {
taskId := c.Param("id")
midjourneyTask := model.GetByMJId(taskId)
midjourneyTask := model.GetByOnlyMJId(taskId)
if midjourneyTask == nil {
c.JSON(400, gin.H{
"error": "midjourney_task_not_found",
@@ -71,14 +80,27 @@ func RelayMidjourneyImage(c *gin.Context) {
})
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
if resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
c.JSON(resp.StatusCode, gin.H{
"error": string(responseBody),
})
return
}
c.Header("Content-Type", "image/jpeg")
//c.HeaderBar("Content-Length", string(rune(len(data))))
c.Data(http.StatusOK, "image/jpeg", data)
// 从Content-Type头获取MIME类型
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
// 如果无法确定内容类型则默认为jpeg
contentType = "image/jpeg"
}
// 设置响应的内容类型
c.Writer.Header().Set("Content-Type", contentType)
// 将图片流式传输到响应体
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
log.Println("Failed to stream image:", err)
}
return
}
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
@@ -92,7 +114,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
Result: "",
}
}
midjourneyTask := model.GetByMJId(midjRequest.MjId)
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &MidjourneyResponse{
Code: 4,
@@ -121,16 +143,7 @@ func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
return nil
}
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
taskId := c.Param("id")
originTask := model.GetByMJId(taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
var midjourneyTask Midjourney
func getMidjourneyTaskModel(c *gin.Context, originTask *model.Midjourney) (midjourneyTask Midjourney) {
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
@@ -150,14 +163,65 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
jsonMap, err := json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
return
}
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
userId := c.GetInt("id")
var err error
var respBody []byte
switch relayMode {
case RelayModeMidjourneyTaskFetch:
taskId := c.Param("id")
originTask := model.GetByMJId(userId, taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
midjourneyTask := getMidjourneyTaskModel(c, originTask)
respBody, err = json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
case RelayModeMidjourneyTaskFetchByCondition:
var condition = struct {
IDs []string `json:"ids"`
}{}
err = c.BindJSON(&condition)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
var tasks []Midjourney
if len(condition.IDs) != 0 {
originTasks := model.GetByMJIds(userId, condition.IDs)
for _, originTask := range originTasks {
midjourneyTask := getMidjourneyTaskModel(c, originTask)
tasks = append(tasks, midjourneyTask)
}
}
if tasks == nil {
tasks = make([]Midjourney, 0)
}
respBody, err = json.Marshal(tasks)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
c.Writer.Header().Set("Content-Type", "application/json")
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
return &MidjourneyResponse{
Code: 4,
@@ -167,6 +231,18 @@ func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
return nil
}
const (
// type 1 根据 mode 价格不同
MJSubmitActionImagine = "IMAGINE"
MJSubmitActionVariation = "VARIATION" //变换
MJSubmitActionBlend = "BLEND" //混图
MJSubmitActionReroll = "REROLL" //重新生成
// type 2 固定价格
MJSubmitActionDescribe = "DESCRIBE"
MJSubmitActionUpscale = "UPSCALE" // 放大
)
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
imageModel := "midjourney"
@@ -186,6 +262,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
}
}
if relayMode == RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return &MidjourneyResponse{
@@ -199,7 +276,45 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
} else if relayMode == RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = "BLEND"
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
originTask := model.GetByMJId(midjRequest.TaskId)
mjId := ""
if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return &MidjourneyResponse{
Code: 4,
Description: "content_is_required",
}
}
params := convertSimpleChangeParams(midjRequest.Content)
if params == nil {
return &MidjourneyResponse{
Code: 4,
Description: "content_parse_failed",
}
}
mjId = params.ID
midjRequest.Action = params.Action
}
originTask := model.GetByMJId(userId, mjId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
@@ -229,23 +344,6 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
log.Printf("检测到此操作为放大、变换获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
}
midjRequest.Prompt = originTask.Prompt
} else if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
}
// map model name
@@ -292,18 +390,27 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
if midjRequest.Action == "UPSCALE" {
sizeRatio = 0.2
mjAction := "mj_" + strings.ToLower(midjRequest.Action)
modelPrice := common.GetModelPrice(mjAction)
// 如果没有配置价格,则使用默认价格
if modelPrice == -1 {
defaultPrice, ok := DefaultModelPrice[mjAction]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
quota := int(ratio * sizeRatio * 1000)
groupRatio := common.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: err.Error(),
}
}
quota := int(ratio * common.QuotaPerUnit)
if consumeQuota && userQuota-quota < 0 {
return &MidjourneyResponse{
@@ -369,7 +476,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, midjRequest.Action)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
@@ -504,3 +611,38 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
}
return nil
}
type taskChangeParams struct {
ID string
Action string
Index int
}
func convertSimpleChangeParams(content string) *taskChangeParams {
split := strings.Split(content, " ")
if len(split) != 2 {
return nil
}
action := strings.ToLower(split[1])
changeParams := &taskChangeParams{}
changeParams.ID = split[0]
if action[0] == 'u' {
changeParams.Action = "UPSCALE"
} else if action[0] == 'v' {
changeParams.Action = "VARIATION"
} else if action == "r" {
changeParams.Action = "REROLL"
return changeParams
} else {
return nil
}
index, err := strconv.Atoi(action[1:2])
if err != nil || index < 1 || index > 4 {
return nil
}
changeParams.Index = index
return changeParams
}

View File

@@ -31,7 +31,7 @@ type PaLMChatRequest struct {
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
TopK uint `json:"topK,omitempty"`
}
type PaLMError struct {

View File

@@ -26,6 +26,7 @@ const (
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
@@ -119,6 +120,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -180,6 +183,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1beta"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
if textRequest.Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
//log.Println(fullRequestURL)
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
@@ -209,14 +231,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
modelPrice := common.GetModelPrice(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
var preConsumedQuota int
var ratio float64
var modelRatio float64
if modelPrice == -1 {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -280,6 +312,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
@@ -375,13 +414,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypeGemini:
req.Header.Set("Content-Type", "application/json")
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
if apiType != APITypeGemini {
// 设置公共头部...
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
resp, err = httpClient.Do(req)
@@ -418,15 +462,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
@@ -443,10 +491,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
// record all the consume log even if quota is 0
useTimeSeconds := time.Now().Unix() - startTime.Unix()
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds)
}
logModel := textRequest.Model
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
//if quota != 0 {
@@ -539,6 +599,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)

View File

@@ -76,12 +76,13 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
}
var config image.Config
var err error
var format string
if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
config, err = common.DecodeUrlImageData(imageUrl.Url)
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, err = common.DecodeBase64ImageData(imageUrl.Url)
config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
}
if err != nil {
return 0, err
@@ -101,7 +102,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
shortSide := config.Width
otherSide := config.Height
log.Printf("width: %d, height: %d", config.Width, config.Height)
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
// 缩放倍数
scale := 1.0
if config.Height < shortSide {
@@ -160,10 +161,27 @@ func countTokenMessages(messages []Message, model string) (int, error) {
} else {
for _, m := range arrayContent {
if m.Type == "image_url" {
imageTokenNum, err := getImageToken(&m.ImageUrl)
var imageTokenNum int
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else {
@@ -177,12 +195,12 @@ func countTokenMessages(messages []Message, model string) (int, error) {
}
func countTokenInput(input any, model string) int {
switch input.(type) {
switch v := input.(type) {
case string:
return countTokenText(input.(string), model)
return countTokenText(v, model)
case []string:
text := ""
for _, s := range input.([]string) {
for _, s := range v {
text += s
}
return countTokenText(text, model)

View File

@@ -33,7 +33,7 @@ type XunfeiChatRequest struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`

View File

@@ -19,9 +19,9 @@ type Message struct {
}
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl MessageImageUrl `json:"image_url,omitempty"`
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
type MessageImageUrl struct {
@@ -29,6 +29,60 @@ type MessageImageUrl struct {
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
@@ -41,8 +95,10 @@ const (
RelayModeMidjourneyDescribe
RelayModeMidjourneyBlend
RelayModeMidjourneyChange
RelayModeMidjourneySimpleChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudio
)
@@ -53,7 +109,7 @@ type GeneralOpenAIRequest struct {
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
@@ -91,14 +147,14 @@ type AudioRequest struct {
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
MaxTokens uint `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
MaxTokens uint `json:"max_tokens"`
//Stream bool `json:"stream"`
}
@@ -209,6 +265,7 @@ type MidjourneyRequest struct {
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
Content string `json:"content"`
}
type MidjourneyResponse struct {
@@ -288,14 +345,19 @@ func RelayMidjourney(c *gin.Context) {
relayMode = RelayModeMidjourneyNotify
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/simple-change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasSuffix(c.Request.URL.Path, "/fetch") {
relayMode = RelayModeMidjourneyTaskFetch
} else if strings.HasSuffix(c.Request.URL.Path, "/list-by-condition") {
relayMode = RelayModeMidjourneyTaskFetchByCondition
}
var err *MidjourneyResponse
switch relayMode {
case RelayModeMidjourneyNotify:
err = relayMidjourneyNotify(c)
case RelayModeMidjourneyTaskFetch:
case RelayModeMidjourneyTaskFetch, RelayModeMidjourneyTaskFetchByCondition:
err = relayMidjourneyTask(c, relayMode)
default:
err = relayMidjourneySubmit(c, relayMode)

View File

@@ -217,6 +217,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.ModelLimitsEnabled = token.ModelLimitsEnabled
cleanToken.ModelLimits = token.ModelLimits
}
err = cleanToken.Update()
if err != nil {

48
controller/usedata.go Normal file
View File

@@ -0,0 +1,48 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/model"
"strconv"
)
func GetAllQuotaDates(c *gin.Context) {
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
username := c.Query("username")
dates, err := model.GetAllQuotaDates(startTimestamp, endTimestamp, username)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dates,
})
return
}
func GetUserQuotaDates(c *gin.Context) {
userId := c.GetInt("id")
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dates,
})
return
}

View File

@@ -152,6 +152,21 @@ func Register(c *gin.Context) {
return
}
}
exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if exist {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户名已存在,或已注销",
})
return
}
affCode := user.AffCode // this code is the inviter's code, not the user's own code
inviterId, _ := model.GetUserIdByAffCode(affCode)
cleanUser := model.User{
@@ -525,7 +540,7 @@ func DeleteUser(c *gin.Context) {
})
return
}
err = model.DeleteUserById(id)
err = model.HardDeleteUserById(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -632,7 +647,7 @@ func ManageUser(c *gin.Context) {
Username: req.Username,
}
// Fill attributes
model.DB.Where(&user).First(&user)
model.DB.Unscoped().Where(&user).First(&user)
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -668,7 +683,7 @@ func ManageUser(c *gin.Context) {
})
return
}
if err := user.Delete(); err != nil {
if err := user.HardDelete(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@@ -1,9 +1,9 @@
version: '3.4'
services:
one-api:
new-api:
image: calciumion/new-api:latest
container_name: one-api
container_name: new-api
restart: always
command: --log-dir /app/logs
ports:
@@ -12,7 +12,7 @@ services:
- ./data:/data
- ./logs:/app/logs
environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai

12
go.mod
View File

@@ -19,7 +19,7 @@ require (
github.com/samber/lo v1.38.1
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.14.0
golang.org/x/crypto v0.17.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
@@ -44,13 +44,14 @@ require (
github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jackc/pgx/v5 v5.5.1 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -64,8 +65,9 @@ require (
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

14
go.sum
View File

@@ -81,6 +81,10 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -108,6 +112,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
@@ -172,11 +178,15 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -189,12 +199,16 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=

18
main.go
View File

@@ -27,7 +27,7 @@ var indexPage []byte
func main() {
common.SetupLogger()
common.SysLog("One API " + common.Version + " started")
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
@@ -67,6 +67,10 @@ func main() {
go model.SyncOptions(common.SyncFrequency)
go model.SyncChannelCache(common.SyncFrequency)
}
// 数据看板
go model.UpdateQuotaData()
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
@@ -81,7 +85,7 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
go controller.UpdateMidjourneyTask()
go controller.UpdateMidjourneyTaskBulk()
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
@@ -100,7 +104,15 @@ func main() {
// Initialize HTTP server
server := gin.New()
server.Use(gin.Recovery())
server.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
common.SysError(fmt.Sprintf("panic detected: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
"type": "new_api_panic",
},
})
}))
// This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.RequestId())

View File

@@ -91,11 +91,11 @@ func TokenAuth() func(c *gin.Context) {
key = c.Request.Header.Get("mj-api-secret")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
parts = strings.Split(key, "-")
key = parts[0]
} else {
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
parts = strings.Split(key, "-")
key = parts[0]
}
token, err := model.ValidateUserToken(key)
@@ -115,6 +115,12 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
if token.ModelLimitsEnabled {
c.Set("token_model_limit_enabled", true)
c.Set("token_model_limit", token.GetModelLimitsMap())
} else {
c.Set("token_model_limit_enabled", false)
}
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {

View File

@@ -50,7 +50,7 @@ func Distribute() func(c *gin.Context) {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
abortWithMessage(c, http.StatusBadRequest, "无效的请求: "+err.Error())
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -77,6 +77,27 @@ func Distribute() func(c *gin.Context) {
}
}
}
// check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
abortWithMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@@ -107,6 +128,8 @@ func Distribute() func(c *gin.Context) {
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
}
c.Next()
}

28
middleware/recover.go Normal file
View File

@@ -0,0 +1,28 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"runtime/debug"
)
func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
common.SysError(fmt.Sprintf("panic detected: %v", err))
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/Calcium-Ion/new-api", err),
"type": "new_api_panic",
},
})
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -6,11 +6,12 @@ import (
)
type Ability struct {
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
Group string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
}
func GetGroupModels(group string) []string {
@@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var abilities []Ability
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
@@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
}
if err != nil {
return nil, err
}
channel := Channel{}
channel.Id = ability.ChannelId
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
if len(abilities) > 0 {
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
weightSum += ability_.Weight
}
if weightSum == 0 {
// All weight is 0, randomly choose one
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
} else {
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight)
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
}
}
} else {
return nil, nil
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
}
@@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
}
abilities = append(abilities, ability)
}

View File

@@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
@@ -149,10 +150,12 @@ func InitChannelCache() {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
newChannelsIDM := make(map[int]*Channel)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
}
for _, channel := range channels {
newChannelsIDM[channel.Id] = channel
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
@@ -177,6 +180,7 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelsIDM = newChannelsIDM
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
@@ -194,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
model = "gpt-4-gizmo-*"
}
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
@@ -214,6 +219,41 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight()
}
if totalWeight == 0 {
// If all weights are 0, select a channel randomly
return channels[rand.Intn(endIdx)], nil
}
// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight
for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight()
if randomWeight <= 0 {
return channel, nil
}
}
// return the last channel if no channel is found
return channels[endIdx-1], nil
}
func CacheGetChannel(id int) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetChannelById(id, true)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
c, ok := channelsIDM[id]
if !ok {
return nil, errors.New(fmt.Sprintf("当前渠道# %d已不存在", id))
}
return c, nil
}

View File

@@ -21,7 +21,7 @@ type Channel struct {
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
@@ -86,6 +86,26 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
func BatchDeleteChannels(ids []int) error {
//使用事务 删除channel表和channel_ability表
tx := DB.Begin()
err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
if err != nil {
// 回滚事务
tx.Rollback()
return err
}
err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
if err != nil {
// 回滚事务
tx.Rollback()
return err
}
// 提交事务
tx.Commit()
return err
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
@@ -93,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
return *channel.Priority
}
func (channel *Channel) GetWeight() int {
if channel.Weight == nil {
return 0
}
return int(*channel.Weight)
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""

View File

@@ -9,7 +9,7 @@ import (
)
type Log struct {
Id int `json:"id;index:idx_created_at_id,priority:1"`
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"`
@@ -59,9 +59,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
if !common.LogConsumeEnabled {
return
}
username := GetUsernameById(userId)
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
Username: username,
CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
@@ -77,6 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
if err != nil {
common.LogError(ctx, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp())
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {

View File

@@ -52,6 +52,14 @@ func chooseDB() (*gorm.DB, error) {
}
// Use MySQL
common.SysLog("using MySQL as database")
// check parseTime
if !strings.Contains(dsn, "parseTime") {
if strings.Contains(dsn, "?") {
dsn += "&parseTime=true"
} else {
dsn += "?parseTime=true"
}
}
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
@@ -119,6 +127,10 @@ func InitDB() (err error) {
if err != nil {
return err
}
err = db.AutoMigrate(&QuotaData{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err

View File

@@ -96,7 +96,7 @@ func GetAllUnFinishTasks() []*Midjourney {
return tasks
}
func GetByMJId(mjId string) *Midjourney {
func GetByOnlyMJId(mjId string) *Midjourney {
var mj *Midjourney
var err error
err = DB.Where("mj_id = ?", mjId).First(&mj).Error
@@ -106,6 +106,26 @@ func GetByMJId(mjId string) *Midjourney {
return mj
}
func GetByMJId(userId int, mjId string) *Midjourney {
var mj *Midjourney
var err error
err = DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetByMJIds(userId int, mjIds []string) []*Midjourney {
var mj []*Midjourney
var err error
err = DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
if err != nil {
return nil
}
return mj
}
func GetMjByuId(id int) *Midjourney {
var mj *Midjourney
var err error
@@ -131,3 +151,9 @@ func (midjourney *Midjourney) Update() error {
err = DB.Save(midjourney).Error
return err
}
func MjBulkUpdate(taskIDs []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", taskIDs).
Updates(params).Error
}

View File

@@ -37,6 +37,8 @@ func InitOptionMap() {
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",")
@@ -70,11 +72,13 @@ func InitOptionMap() {
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -156,6 +160,12 @@ func updateOptionMap(key string, value string) (err error) {
common.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled":
common.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled":
common.DisplayTokenStatEnabled = boolValue
case "DrawingEnabled":
common.DrawingEnabled = boolValue
case "DataExportEnabled":
common.DataExportEnabled = boolValue
}
}
switch key {
@@ -216,10 +226,14 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
common.DataExportInterval, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "ModelPrice":
err = common.UpdateModelPriceByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChatLink":

View File

@@ -10,17 +10,19 @@ import (
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
ModelLimitsEnabled bool `json:"model_limits_enabled" gorm:"default:false"`
ModelLimits string `json:"model_limits" gorm:"type:varchar(1024);default:''"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@@ -45,7 +47,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
return nil, errors.New("该令牌额度已用尽 token.Status == common.TokenStatusExhausted " + key)
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
@@ -71,7 +73,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
return nil, errors.New(fmt.Sprintf("%s 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", token.Key, token.RemainQuota))
}
return token, nil
}
@@ -107,7 +109,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "model_limits_enabled", "model_limits").Updates(token).Error
return err
}
@@ -122,6 +124,36 @@ func (token *Token) Delete() error {
return err
}
func (token *Token) IsModelLimitsEnabled() bool {
return token.ModelLimitsEnabled
}
func (token *Token) GetModelLimits() []string {
if token.ModelLimits == "" {
return []string{}
}
return strings.Split(token.ModelLimits, ",")
}
func (token *Token) GetModelLimitsMap() map[string]bool {
limits := token.GetModelLimits()
limitsMap := make(map[string]bool)
for _, limit := range limits {
limitsMap[limit] = true
}
return limitsMap
}
func DisableModelLimits(tokenId int) error {
token, err := GetTokenById(tokenId)
if err != nil {
return err
}
token.ModelLimitsEnabled = false
token.ModelLimits = ""
return token.Update()
}
func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 {

115
model/usedata.go Normal file
View File

@@ -0,0 +1,115 @@
package model
import (
"fmt"
"one-api/common"
"sync"
"time"
)
// QuotaData 柱状图数据
type QuotaData struct {
Id int `json:"id"`
UserID int `json:"user_id" gorm:"index"`
Username string `json:"username" gorm:"index:idx_qdt_model_user_name,priority:2;size:64;default:''"`
ModelName string `json:"model_name" gorm:"index:idx_qdt_model_user_name,priority:1;size:64;default:''"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_qdt_created_at,priority:2"`
Count int `json:"count" gorm:"default:0"`
Quota int `json:"quota" gorm:"default:0"`
}
func UpdateQuotaData() {
// recover
defer func() {
if r := recover(); r != nil {
common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
}
}()
for {
if common.DataExportEnabled {
common.SysLog("正在更新数据看板数据...")
SaveQuotaDataCache()
}
time.Sleep(time.Duration(common.DataExportInterval) * time.Minute)
}
}
var CacheQuotaData = make(map[string]*QuotaData)
var CacheQuotaDataLock = sync.Mutex{}
func LogQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64) {
// 只精确到小时
createdAt = createdAt - (createdAt % 3600)
key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt)
quotaData, ok := CacheQuotaData[key]
if ok {
quotaData.Count += 1
quotaData.Quota += quota
} else {
quotaData = &QuotaData{
UserID: userId,
Username: username,
ModelName: modelName,
CreatedAt: createdAt,
Count: 1,
Quota: quota,
}
}
CacheQuotaData[key] = quotaData
}
func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64) {
CacheQuotaDataLock.Lock()
defer CacheQuotaDataLock.Unlock()
LogQuotaDataCache(userId, username, modelName, quota, createdAt)
}
func SaveQuotaDataCache() {
CacheQuotaDataLock.Lock()
defer CacheQuotaDataLock.Unlock()
size := len(CacheQuotaData)
// 如果缓存中有数据,就保存到数据库中
// 1. 先查询数据库中是否有数据
// 2. 如果有数据,就更新数据
// 3. 如果没有数据,就插入数据
for _, quotaData := range CacheQuotaData {
quotaDataDB := &QuotaData{}
DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.CreatedAt).First(quotaDataDB)
if quotaDataDB.Id > 0 {
quotaDataDB.Count += quotaData.Count
quotaDataDB.Quota += quotaData.Quota
DB.Table("quota_data").Save(quotaDataDB)
} else {
DB.Table("quota_data").Create(quotaData)
}
}
CacheQuotaData = make(map[string]*QuotaData)
common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
}
func GetQuotaDataByUsername(username string, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
var quotaDatas []*QuotaData
// 从quota_data表中查询数据
err = DB.Table("quota_data").Where("username = ?", username).Find(&quotaDatas).Error
return quotaDatas, err
}
func GetQuotaDataByUserId(userId int, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
var quotaDatas []*QuotaData
// 从quota_data表中查询数据
err = DB.Table("quota_data").Where("user_id = ? and created_at >= ? and created_at <= ?", userId, startTime, endTime).Find(&quotaDatas).Error
return quotaDatas, err
}
func GetAllQuotaDates(startTime int64, endTime int64, username string) (quotaData []*QuotaData, err error) {
if username != "" {
return GetQuotaDataByUsername(username, startTime, endTime)
}
var quotaDatas []*QuotaData
// 从quota_data表中查询数据
// only select model_name, sum(count) as count, sum(quota) as quota, model_name, created_at from quota_data group by model_name, created_at;
//err = DB.Table("quota_data").Where("created_at >= ? and created_at <= ?", startTime, endTime).Find(&quotaDatas).Error
err = DB.Table("quota_data").Select("model_name, sum(count) as count, sum(quota) as quota, created_at").Where("created_at >= ? and created_at <= ?", startTime, endTime).Group("model_name, created_at").Find(&quotaDatas).Error
return quotaDatas, err
}

View File

@@ -6,31 +6,57 @@ import (
"gorm.io/gorm"
"one-api/common"
"strings"
"time"
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
Id int `json:"id"`
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
Id int `json:"id"`
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
// err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
// check email if empty
var err error
if email == "" {
err = DB.Unscoped().First(&user, "username = ?", username).Error
} else {
err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
}
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// not exist, return false, nil
return false, nil
}
// other error, return false, err
return false, err
}
// exist, return true, nil
return true, nil
}
func GetMaxUserId() int {
@@ -40,7 +66,7 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
err = DB.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
return users, err
}
@@ -80,6 +106,14 @@ func DeleteUserById(id int) (err error) {
return user.Delete()
}
func HardDeleteUserById(id int) error {
if id == 0 {
return errors.New("id 为空!")
}
err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
return err
}
func inviteUser(inviterId int) (err error) {
user, err := GetUserById(inviterId, true)
if err != nil {
@@ -169,9 +203,13 @@ func (user *User) Update(updatePassword bool) error {
}
}
newUser := *user
DB.First(&user, user.Id)
err = DB.Model(user).Updates(newUser).Error
if err == nil {
if common.RedisEnabled {
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
}
}
return err
}
@@ -183,6 +221,14 @@ func (user *User) Delete() error {
return err
}
func (user *User) HardDelete() error {
if user.Id == 0 {
return errors.New("id 为空!")
}
err := DB.Unscoped().Delete(user).Error
return err
}
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,

View File

@@ -83,6 +83,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())
@@ -113,6 +114,10 @@ func SetApiRouter(router *gin.Engine) {
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
dataRoute := apiRouter.Group("/data")
dataRoute.GET("/", middleware.AdminAuth(), controller.GetAllQuotaDates)
dataRoute.GET("/self", middleware.UserAuth(), controller.GetUserQuotaDates)
logRoute.Use(middleware.CORS())
{
logRoute.GET("/token", controller.GetLogByKey)

View File

@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("/:model", controller.RetrieveModel)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)
@@ -49,10 +49,12 @@ func SetRelayRouter(router *gin.Engine) {
{
relayMjRouter.POST("/submit/imagine", controller.RelayMidjourney)
relayMjRouter.POST("/submit/change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/simple-change", controller.RelayMidjourney)
relayMjRouter.POST("/submit/describe", controller.RelayMidjourney)
relayMjRouter.POST("/submit/blend", controller.RelayMidjourney)
relayMjRouter.POST("/notify", controller.RelayMidjourney)
relayMjRouter.GET("/task/:id/fetch", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
}
//relayMjRouter.Use()
}

View File

@@ -3,13 +3,17 @@
"version": "0.1.0",
"private": true,
"dependencies": {
"@douyinfe/semi-ui": "^2.45.2",
"@douyinfe/semi-ui": "^2.46.1",
"@visactor/vchart": "~1.7.2",
"@visactor/react-vchart": "~1.7.2",
"@visactor/vchart-semi-theme": "~1.7.2",
"axios": "^0.27.2",
"history": "^5.3.0",
"marked": "^4.1.1",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
"react-fireworks": "^1.0.4",
"react-router-dom": "^6.3.0",
"react-scripts": "5.0.1",
"react-toastify": "^9.0.8",
@@ -43,7 +47,8 @@
]
},
"devDependencies": {
"prettier": "^2.7.1"
"prettier": "^2.7.1",
"typescript": "4.4.2"
},
"prettier": {
"singleQuote": true,

View File

@@ -23,10 +23,10 @@ import Log from './pages/Log';
import Chat from './pages/Chat';
import {Layout} from "@douyinfe/semi-ui";
import Midjourney from "./pages/Midjourney";
import Detail from "./pages/Detail";
const Home = lazy(() => import('./pages/Home'));
const About = lazy(() => import('./pages/About'));
function App() {
const [userState, userDispatch] = useContext(UserContext);
const [statusState, statusDispatch] = useContext(StatusContext);
@@ -49,6 +49,8 @@ function App() {
localStorage.setItem('footer_html', data.footer_html);
localStorage.setItem('quota_per_unit', data.quota_per_unit);
localStorage.setItem('display_in_currency', data.display_in_currency);
localStorage.setItem('enable_drawing', data.enable_drawing);
localStorage.setItem('enable_data_export', data.enable_data_export);
if (data.chat_link) {
localStorage.setItem('chat_link', data.chat_link);
} else {
@@ -228,6 +230,14 @@ function App() {
</PrivateRoute>
}
/>
<Route
path='/detail'
element={
<PrivateRoute>
<Detail />
</PrivateRoute>
}
/>
<Route
path='/midjourney'
element={

View File

@@ -74,6 +74,11 @@ function renderBalance(type, balance) {
const ChannelsTable = () => {
const columns = [
// {
// title: '',
// dataIndex: 'checkbox',
// className: 'checkbox',
// },
{
title: 'ID',
dataIndex: 'id',
@@ -158,11 +163,30 @@ const ChannelsTable = () => {
<div>
<InputNumber
style={{width: 70}}
name='name'
name='priority'
onChange={value => {
manageChannel(record.id, 'priority', record, value);
}}
defaultValue={record.priority}
min={-999}
/>
</div>
);
},
},
{
title: '权重',
dataIndex: 'weight',
render: (text, record, index) => {
return (
<div>
<InputNumber
style={{width: 70}}
name='weight'
onChange={value => {
manageChannel(record.id, 'weight', record, value);
}}
defaultValue={record.weight}
min={0}
/>
</div>
@@ -235,9 +259,11 @@ const ChannelsTable = () => {
const [channelCount, setChannelCount] = useState(pageSize);
const [groupOptions, setGroupOptions] = useState([]);
const [showEdit, setShowEdit] = useState(false);
const [enableBatchDelete, setEnableBatchDelete] = useState(false);
const [editingChannel, setEditingChannel] = useState({
id: undefined,
});
const [selectedChannels, setSelectedChannels] = useState([]);
const removeRecord = id => {
let newDataSource = [...channels];
@@ -484,6 +510,27 @@ const ChannelsTable = () => {
setUpdatingBalance(false);
};
const batchDeleteChannels = async () => {
if (selectedChannels.length === 0) {
showError('请先选择要删除的通道!');
return;
}
setLoading(true);
let ids = [];
selectedChannels.forEach((channel) => {
ids.push(channel.id);
});
const res = await API.post(`/api/channel/batch`, {ids: ids});
const {success, message, data} = res.data;
if (success) {
showSuccess(`已删除 ${data} 个通道!`);
await refresh();
} else {
showError(message);
}
setLoading(false);
}
const sortChannel = (key) => {
if (channels.length === 0) return;
setLoading(true);
@@ -557,6 +604,7 @@ const ChannelsTable = () => {
}
};
return (
<>
<EditChannel refresh={refresh} visible={showEdit} handleClose={closeEdit} editingChannel={editingChannel}/>
@@ -583,16 +631,18 @@ const ChannelsTable = () => {
</Form>
<div style={{marginTop: 10, display: 'flex'}}>
<Space>
<Typography.Text strong>使用ID排序</Typography.Text>
<Switch checked={idSort} label='使用ID排序' uncheckedText="关" aria-label="是否用ID排序" onChange={(v) => {
localStorage.setItem('id-sort', v + '')
setIdSort(v)
loadChannels(0, pageSize, v)
.then()
.catch((reason) => {
showError(reason);
})
}}></Switch>
<Space>
<Typography.Text strong>使用ID排序</Typography.Text>
<Switch checked={idSort} label='使用ID排序' uncheckedText="关" aria-label="是否用ID排序" onChange={(v) => {
localStorage.setItem('id-sort', v + '')
setIdSort(v)
loadChannels(0, pageSize, v)
.then()
.catch((reason) => {
showError(reason);
})
}}></Switch>
</Space>
</Space>
</div>
@@ -607,7 +657,15 @@ const ChannelsTable = () => {
handlePageSizeChange(size).then()
},
onPageChange: handlePageChange,
}} loading={loading} onRow={handleRow}/>
}} loading={loading} onRow={handleRow} rowSelection={
enableBatchDelete ?
{
onChange: (selectedRowKeys, selectedRows) => {
// console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows);
setSelectedChannels(selectedRows);
},
} : null
}/>
<div style={{display: isMobile()?'':'flex', marginTop: isMobile()?0:-45, zIndex: 999, position: 'relative', pointerEvents: 'none'}}>
<Space style={{pointerEvents: 'auto'}}>
<Button theme='light' type='primary' style={{marginRight: 8}} onClick={
@@ -622,7 +680,7 @@ const ChannelsTable = () => {
title="确定?"
okType={'warning'}
onConfirm={testAllChannels}
position={isMobile()?'top':''}
position={isMobile()?'top':'top'}
>
<Button theme='light' type='warning' style={{marginRight: 8}}>测试所有已启用通道</Button>
</Popconfirm>
@@ -648,6 +706,24 @@ const ChannelsTable = () => {
{/*</div>*/}
</div>
<div style={{marginTop: 20}}>
<Space>
<Typography.Text strong>开启批量删除</Typography.Text>
<Switch label='开启批量删除' uncheckedText="关" aria-label="是否开启批量删除" onChange={(v) => {
setEnableBatchDelete(v)
}}></Switch>
<Popconfirm
title="确定是否要删除所选通道?"
content="此修改将不可逆"
okType={'danger'}
onConfirm={batchDeleteChannels}
disabled={!enableBatchDelete}
position={'top'}
>
<Button disabled={!enableBatchDelete} theme='light' type='danger' style={{marginRight: 8}}>删除所选通道</Button>
</Popconfirm>
</Space>
</div>
</>
);
};

View File

@@ -1,4 +1,4 @@
import React, {useContext, useEffect, useState} from 'react';
import React, {useContext, useEffect, useRef, useState} from 'react';
import {Link, useNavigate} from 'react-router-dom';
import {UserContext} from '../context/User';
@@ -6,18 +6,12 @@ import {Button, Container, Icon, Menu, Segment} from 'semantic-ui-react';
import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../helpers';
import '../index.css';
import fireworks from 'react-fireworks';
import {
IconAt,
IconHistogram,
IconGift,
IconKey,
IconUser,
IconLayers,
IconHelpCircle,
IconCreditCard,
IconSemiLogo,
IconHome,
IconImage
IconHelpCircle
} from '@douyinfe/semi-icons';
import {Nav, Avatar, Dropdown, Layout, Switch} from '@douyinfe/semi-ui';
import {stringToColor} from "../helpers/render";
@@ -49,6 +43,8 @@ const HeaderBar = () => {
const systemName = getSystemName();
const logo = getLogo();
var themeMode = localStorage.getItem('theme-mode');
const currentDate = new Date();
const isNewYear = currentDate.getMonth() === 0 && currentDate.getDate() === 1;
async function logout() {
setShowSidebar(false);
@@ -59,10 +55,24 @@ const HeaderBar = () => {
navigate('/login');
}
const handleNewYearClick = () => {
fireworks.init("root",{});
fireworks.start();
setTimeout(() => {
fireworks.stop();
setTimeout(() => {
window.location.reload();
}, 10000);
}, 3000);
};
useEffect(() => {
if (themeMode === 'dark') {
switchMode(true);
}
if (isNewYear) {
console.log('Happy New Year!');
}
}, []);
const switchMode = (model) => {
@@ -105,6 +115,19 @@ const HeaderBar = () => {
}}
footer={
<>
{isNewYear &&
// happy new year
<Dropdown
position="bottomRight"
render={
<Dropdown.Menu>
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
</Dropdown.Menu>
}
>
<Nav.Item itemKey={'new-year'} text={'🏮'}/>
</Dropdown>
}
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
{userState.user ?

View File

@@ -287,7 +287,7 @@ const LogsTable = () => {
// data.key = '' + data.id
setLogs(logs);
setLogCount(logs.length + ITEMS_PER_PAGE);
console.log(logCount);
// console.log(logCount);
}
const loadLogs = async (startIdx) => {
@@ -422,7 +422,6 @@ const LogsTable = () => {
value={end_timestamp} type='dateTime'
name='end_timestamp'
onChange={value => handleInputChange(value, 'end_timestamp')}/>
{/*<Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>*/}
{
isAdminUser && <>
<Form.Input field="channel" label='渠道 ID' style={{width: 176}} value={channel}

View File

@@ -2,7 +2,19 @@ import React, {useEffect, useState} from 'react';
import {Label} from 'semantic-ui-react';
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal } from '@douyinfe/semi-ui';
import {
Table,
Avatar,
Tag,
Form,
Button,
Layout,
Select,
Popover,
Modal,
ImagePreview,
Typography
} from '@douyinfe/semi-ui';
import {ITEMS_PER_PAGE} from '../constants';
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
@@ -194,19 +206,16 @@ const LogsTable = () => {
}
return (
text.length > 10 ?
<>
{text.slice(0, 10)}
<Button
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
查看全部
</Button>
</>
: text
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
@@ -220,19 +229,16 @@ const LogsTable = () => {
}
return (
text.length > 10 ?
<>
{text.slice(0, 10)}
<Button
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
查看全部
</Button>
</>
: text
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
@@ -246,19 +252,16 @@ const LogsTable = () => {
}
return (
text.length > 10 ?
<>
{text.slice(0, 10)}
<Button
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
查看全部
</Button>
</>
: text
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
}
@@ -304,7 +307,7 @@ const LogsTable = () => {
// data.key = '' + data.id
setLogs(logs);
setLogCount(logs.length + ITEMS_PER_PAGE);
console.log(logCount);
// console.log(logCount);
}
const loadLogs = async (startIdx) => {
@@ -414,15 +417,11 @@ const LogsTable = () => {
>
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
</Modal>
{/* 模态框组件,用于展示图片 */}
<Modal
title="图片预览"
visible={isModalOpenurl}
onCancel={() => setIsModalOpenurl(false)}
footer={null} // 模态框不显示底部按钮
>
<img src={modalImageUrl} style={{ width: '100%' }} alt="结果图片" />
</Modal>
<ImagePreview
src={modalImageUrl}
visible={isModalOpenurl}
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
/>
</Layout>
</>

View File

@@ -3,13 +3,15 @@ import {Divider, Form, Grid, Header} from 'semantic-ui-react';
import {API, showError, showSuccess, timestamp2string, verifyJSON} from '../helpers';
const OperationSetting = () => {
let now = new Date();let [inputs, setInputs] = useState({
let now = new Date();
let [inputs, setInputs] = useState({
QuotaForNewUser: 0,
QuotaForInviter: 0,
QuotaForInvitee: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
ModelPrice: '',
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
@@ -19,28 +21,32 @@ const OperationSetting = () => {
LogConsumeEnabled: '',
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
DrawingEnabled: '',
DataExportEnabled: '',
DataExportInterval: 5,
RetryTimes: 0
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
let [loading, setLoading] = useState(false);
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
const getOptions = async () => {
const res = await API.get('/api/option/');
const { success, message, data } = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
const getOptions = async () => {
const res = await API.get('/api/option/');
const {success, message, data} = res.data;
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio' || item.key === 'ModelPrice') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
});
setInputs(newInputs);
setOriginInputs(newInputs);
} else {
showError(message);
}
newInputs[item.key] = item.value;
});
setInputs(newInputs);
setOriginInputs(newInputs);
} else {
showError(message);
}
};
};
useEffect(() => {
getOptions().then();
@@ -65,80 +71,88 @@ const OperationSetting = () => {
};
const handleInputChange = async (e, {name, value}) => {
if (name.endsWith('Enabled')) {
if (name.endsWith('Enabled') || name === 'DataExportInterval') {
await updateOption(name, value);
} else {
setInputs((inputs) => ({...inputs, [name]: value}));
}
};
const submitConfig = async (group) => {
switch (group) {
case 'monitor':
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
const submitConfig = async (group) => {
switch (group) {
case 'monitor':
if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
}
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
return;
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
if (!verifyJSON(inputs.ModelPrice)) {
showError('模型固定价格不是合法的 JSON 字符串');
return;
}
await updateOption('ModelPrice', inputs.ModelPrice);
}
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
}
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
}
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
}
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
}
break;
case 'general':
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
}
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
break;
}
if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
}
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
return;
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
}
if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
}
if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
await updateOption('QuotaForInviter', inputs.QuotaForInviter);
}
if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
}
break;
case 'general':
if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink);
}
if (originInputs['ChatLink'] !== inputs.ChatLink) {
await updateOption('ChatLink', inputs.ChatLink);
}
if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
}
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
break;
}
};
};
const deleteHistoryLogs = async () => {
console.log(inputs);
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`${data} 条日志已清理!`);
return;
}
showError('日志清理失败:' + message);
};return (
console.log(inputs);
const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
const {success, message, data} = res.data;
if (success) {
showSuccess(`${data} 条日志已清理!`);
return;
}
showError('日志清理失败:' + message);
};
return (
<Grid columns={1}>
<Grid.Column>
<Form loading={loading}>
@@ -200,31 +214,58 @@ const OperationSetting = () => {
name='DisplayTokenStatEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.DrawingEnabled === 'true'}
label='启用绘图功能'
name='DrawingEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Button onClick={() => {
submitConfig('general').then();
}}>保存通用设置</Form.Button><Divider />
<Header as='h3'>
日志设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, { name, value }) => {
setHistoryTimestamp(value);
}} />
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
}}>保存通用设置</Form.Button><Divider/>
<Header as='h3'>
日志设置
</Header>
<Form.Group inline>
<Form.Checkbox
checked={inputs.LogConsumeEnabled === 'true'}
label='启用额度消费日志记录'
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.DataExportEnabled === 'true'}
label='启用数据看板(实验性)'
name='DataExportEnabled'
onChange={handleInputChange}
/>
<Form.Input
label='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
name='DataExportInterval'
type={'number'}
step='1'
min='1'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.DataExportInterval}
placeholder='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
/>
</Form.Group>
<Divider/>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, {name, value}) => {
setHistoryTimestamp(value);
}}/>
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider/>
<Header as='h3'>
监控设置
@@ -315,6 +356,17 @@ const OperationSetting = () => {
<Header as='h3'>
倍率设置
</Header>
<Form.Group widths='equal'>
<Form.TextArea
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
name='ModelPrice'
onChange={handleInputChange}
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
autoComplete='new-password'
value={inputs.ModelPrice}
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1一次消耗0.1刀'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='模型倍率'

View File

@@ -20,6 +20,7 @@ import {
import {getQuotaPerUnit, renderQuota, renderQuotaWithPrompt, stringToColor} from "../helpers/render";
import EditToken from "../pages/Token/EditToken";
import EditUser from "../pages/User/EditUser";
import passwordResetConfirm from "./PasswordResetConfirm";
const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
@@ -29,9 +30,12 @@ const PersonalSetting = () => {
wechat_verification_code: '',
email_verification_code: '',
email: '',
self_account_deletion_confirmation: ''
self_account_deletion_confirmation: '',
set_new_password: '',
set_new_password_confirmation: '',
});
const [status, setStatus] = useState({});
const [showChangePasswordModal, setShowChangePasswordModal] = useState(false);
const [showWeChatBindModal, setShowWeChatBindModal] = useState(false);
const [showEmailBindModal, setShowEmailBindModal] = useState(false);
const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false);
@@ -180,6 +184,27 @@ const PersonalSetting = () => {
}
};
const changePassword = async () => {
if (inputs.set_new_password !== inputs.set_new_password_confirmation) {
showError('两次输入的密码不一致!');
return;
}
const res = await API.put(
`/api/user/self`,
{
password: inputs.set_new_password
}
);
const {success, message} = res.data;
if (success) {
showSuccess('密码修改成功!');
setShowWeChatBindModal(false);
} else {
showError(message);
}
setShowChangePasswordModal(false);
};
const transfer = async () => {
if (transferAmount < getQuotaPerUnit()) {
showError('划转金额最低为' + renderQuota(getQuotaPerUnit()));
@@ -420,6 +445,9 @@ const PersonalSetting = () => {
<Space>
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
<Button onClick={() => {
setShowChangePasswordModal(true);
}}>修改密码</Button>
<Button type={'danger'} onClick={() => {
setShowAccountDeleteModal(true);
}}>删除个人账户</Button>
</Space>
@@ -543,6 +571,39 @@ const PersonalSetting = () => {
)}
</div>
</Modal>
<Modal
onCancel={() => setShowChangePasswordModal(false)}
visible={showChangePasswordModal}
size={'small'}
centered={true}
onOk={changePassword}
>
<div style={{marginTop: 20}}>
<Input
name='set_new_password'
placeholder='新密码'
value={inputs.set_new_password}
onChange={(value)=>handleInputChange('set_new_password', value)}
/>
<Input
style={{marginTop: 20}}
name='set_new_password_confirmation'
placeholder='确认新密码'
value={inputs.set_new_password_confirmation}
onChange={(value)=>handleInputChange('set_new_password_confirmation', value)}
/>
{turnstileEnabled ? (
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
}}
/>
) : (
<></>
)}
</div>
</Modal>
</div>
</Layout.Content>

View File

@@ -394,7 +394,7 @@ const RedemptionsTable = () => {
}
let keys = "";
for (let i = 0; i < selectedKeys.length; i++) {
keys += selectedKeys[i].name + " sk-" + selectedKeys[i].key + "\n";
keys += selectedKeys[i].name + " " + selectedKeys[i].key + "\n";
}
await copyText(keys);
}

View File

@@ -1,4 +1,4 @@
import React, {useContext, useState} from 'react';
import React, {useContext, useMemo, useState} from 'react';
import {Link, useNavigate} from 'react-router-dom';
import {UserContext} from '../context/User';
@@ -6,7 +6,7 @@ import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../he
import '../index.css';
import {
IconAt,
IconCalendarClock,
IconHistogram,
IconGift,
IconKey,
@@ -21,78 +21,6 @@ import {
import {Nav, Avatar, Dropdown, Layout} from '@douyinfe/semi-ui';
// HeaderBar Buttons
let headerButtons = [
{
text: '首页',
itemKey: 'home',
to: '/',
icon: <IconHome/>
},
{
text: '渠道',
itemKey: 'channel',
to: '/channel',
icon: <IconLayers/>,
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '聊天',
itemKey: 'chat',
to: '/chat',
icon: <IconComment />,
className: localStorage.getItem('chat_link')?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '令牌',
itemKey: 'token',
to: '/token',
icon: <IconKey/>
},
{
text: '兑换码',
itemKey: 'redemption',
to: '/redemption',
icon: <IconGift/>,
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '钱包',
itemKey: 'topup',
to: '/topup',
icon: <IconCreditCard/>
},
{
text: '用户管理',
itemKey: 'user',
to: '/user',
icon: <IconUser/>,
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '日志',
itemKey: 'log',
to: '/log',
icon: <IconHistogram/>
},
{
text: '绘图',
itemKey: 'midjourney',
to: '/midjourney',
icon: <IconImage/>
},
{
text: '设置',
itemKey: 'setting',
to: '/setting',
icon: <IconSetting/>
},
// {
// text: '关于',
// itemKey: 'about',
// to: '/about',
// icon: <IconAt/>
// }
];
const SiderBar = () => {
const [userState, userDispatch] = useContext(UserContext);
@@ -101,6 +29,87 @@ const SiderBar = () => {
const [showSidebar, setShowSidebar] = useState(false);
const systemName = getSystemName();
const logo = getLogo();
const headerButtons = useMemo(() => [
{
text: '首页',
itemKey: 'home',
to: '/',
icon: <IconHome/>
},
{
text: '渠道',
itemKey: 'channel',
to: '/channel',
icon: <IconLayers/>,
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '聊天',
itemKey: 'chat',
to: '/chat',
icon: <IconComment />,
className: localStorage.getItem('chat_link')?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '令牌',
itemKey: 'token',
to: '/token',
icon: <IconKey/>
},
{
text: '兑换码',
itemKey: 'redemption',
to: '/redemption',
icon: <IconGift/>,
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '钱包',
itemKey: 'topup',
to: '/topup',
icon: <IconCreditCard/>
},
{
text: '用户管理',
itemKey: 'user',
to: '/user',
icon: <IconUser/>,
className: isAdmin()?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '日志',
itemKey: 'log',
to: '/log',
icon: <IconHistogram/>
},
{
text: '数据看版',
itemKey: 'detail',
to: '/detail',
icon: <IconCalendarClock />,
className: localStorage.getItem('enable_data_export') === 'true'?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '绘图',
itemKey: 'midjourney',
to: '/midjourney',
icon: <IconImage/>,
className: localStorage.getItem('enable_drawing') === 'true'?'semi-navigation-item-normal':'tableHiddle',
},
{
text: '设置',
itemKey: 'setting',
to: '/setting',
icon: <IconSetting/>
},
// {
// text: '关于',
// itemKey: 'about',
// to: '/about',
// icon: <IconAt/>
// }
], [localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), localStorage.getItem('chat_link'), isAdmin()]);
async function logout() {
setShowSidebar(false);
@@ -133,6 +142,7 @@ const SiderBar = () => {
setting: "/setting",
about: "/about",
chat: "/chat",
detail: "/detail",
};
return (
<Link

View File

@@ -43,10 +43,14 @@ function renderTimestamp(timestamp) {
);
}
function renderStatus(status) {
function renderStatus(status, model_limits_enabled = false) {
switch (status) {
case 1:
return <Tag color='green' size='large'>已启用</Tag>;
if (model_limits_enabled) {
return <Tag color='green' size='large'>已启用限制模型</Tag>;
} else {
return <Tag color='green' size='large'>已启用</Tag>;
}
case 2:
return <Tag color='red' size='large'> 已禁用 </Tag>;
case 3:
@@ -78,7 +82,7 @@ const TokensTable = () => {
render: (text, record, index) => {
return (
<div>
{renderStatus(text)}
{renderStatus(text, record.model_limits_enabled)}
</div>
);
},
@@ -224,6 +228,11 @@ const TokensTable = () => {
const closeEdit = () => {
setShowEdit(false);
setTimeout(() => {
setEditingToken({
id: undefined,
});
}, 500);
}
const setTokensFormat = (tokens) => {

View File

@@ -72,32 +72,49 @@ const UsersTable = () => {
}, {
title: '状态', dataIndex: 'status', render: (text, record, index) => {
return (<div>
{renderStatus(text)}
{record.DeletedAt !== null? <Tag color='red'>已注销</Tag> : renderStatus(text)}
</div>);
},
}, {
title: '', dataIndex: 'operate', render: (text, record, index) => (<div>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'promote', record)
}}
>
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
</Popconfirm>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'demote', record)
}}
>
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
</Popconfirm>
{
record.DeletedAt !== null ? <></>:
<>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'promote', record)
}}
>
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
</Popconfirm>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'demote', record)
}}
>
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
</Popconfirm>
{record.status === 1 ?
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'disable', record)
}}>禁用</Button> :
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'enable', record);
}} disabled={record.status === 3}>启用</Button>}
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
setEditingUser(record);
setShowEditUser(true);
}}>编辑</Button>
</>
}
<Popconfirm
title="确定是否要删除此用户?"
content="此修改将不可逆"
content="硬删除,此修改将不可逆"
okType={'danger'}
position={'left'}
onConfirm={() => {
@@ -108,17 +125,6 @@ const UsersTable = () => {
>
<Button theme='light' type='danger' style={{marginRight: 1}}>删除</Button>
</Popconfirm>
{record.status === 1 ?
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'disable', record)
}}>禁用</Button> :
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'enable', record);
}} disabled={record.status === 3}>启用</Button>}
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
setEditingUser(record);
setShowEditUser(true);
}}>编辑</Button>
</div>),
},];

View File

@@ -1,16 +1,17 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI' },
{ key: 24, text: 'Midjourney Proxy', value: 24, color: 'light-blue', label: 'Midjourney Proxy' },
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue', label: '知识库FastGPT' },
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple', label: '知识库:AI Proxy' },
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI'},
{key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2'},
{key: 24, text: 'Google Gemini', value: 24, color: 'orange', label: 'Google Gemini'},
{key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆'},
{key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问'},
{key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知'},
{key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM'},
{key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑'},
{key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元'},
{key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道'},
{key: 22, text: '知识库:FastGPT', value: 22, color: 'blue', label: '知识库:FastGPT'},
{key: 21, text: '知识库AI Proxy', value: 21, color: 'purple', label: '知识库AI Proxy'},
];

View File

@@ -1,114 +1,129 @@
import { Label } from 'semantic-ui-react';
import {Label} from 'semantic-ui-react';
import {Tag} from "@douyinfe/semi-ui";
export function renderText(text, limit) {
if (text.length > limit) {
return text.slice(0, limit - 3) + '...';
}
return text;
if (text.length > limit) {
return text.slice(0, limit - 3) + '...';
}
return text;
}
export function renderGroup(group) {
if (group === '') {
return <Tag size='large'>default</Tag>;
}
let groups = group.split(',');
groups.sort();
return <>
{groups.map((group) => {
if (group === 'vip' || group === 'pro') {
return <Tag size='large' color='yellow'>{group}</Tag>;
} else if (group === 'svip' || group === 'premium') {
return <Tag size='large' color='red'>{group}</Tag>;
}
if (group === 'default') {
return <Tag size='large'>{group}</Tag>;
} else {
return <Tag size='large' color={stringToColor(group)}>{group}</Tag>;
}
})}
</>;
if (group === '') {
return <Tag size='large'>default</Tag>;
}
let groups = group.split(',');
groups.sort();
return <>
{groups.map((group) => {
if (group === 'vip' || group === 'pro') {
return <Tag size='large' color='yellow'>{group}</Tag>;
} else if (group === 'svip' || group === 'premium') {
return <Tag size='large' color='red'>{group}</Tag>;
}
if (group === 'default') {
return <Tag size='large'>{group}</Tag>;
} else {
return <Tag size='large' color={stringToColor(group)}>{group}</Tag>;
}
})}
</>;
}
export function renderNumber(num) {
if (num >= 1000000000) {
return (num / 1000000000).toFixed(1) + 'B';
} else if (num >= 1000000) {
return (num / 1000000).toFixed(1) + 'M';
} else if (num >= 10000) {
return (num / 1000).toFixed(1) + 'k';
} else {
if (num >= 1000000000) {
return (num / 1000000000).toFixed(1) + 'B';
} else if (num >= 1000000) {
return (num / 1000000).toFixed(1) + 'M';
} else if (num >= 10000) {
return (num / 1000).toFixed(1) + 'k';
} else {
return num;
}
}
export function renderQuotaNumberWithDigit(num, digits = 2) {
let displayInCurrency = localStorage.getItem('display_in_currency');
num = num.toFixed(digits);
if (displayInCurrency) {
return '$' + num;
}
return num;
}
}
export function renderNumberWithPoint(num) {
num = num.toFixed(2);
if (num >= 100000) {
// Convert number to string to manipulate it
let numStr = num.toString();
// Find the position of the decimal point
let decimalPointIndex = numStr.indexOf('.');
num = num.toFixed(2);
if (num >= 100000) {
// Convert number to string to manipulate it
let numStr = num.toString();
// Find the position of the decimal point
let decimalPointIndex = numStr.indexOf('.');
let wholePart = numStr;
let decimalPart = '';
let wholePart = numStr;
let decimalPart = '';
// If there is a decimal point, split the number into whole and decimal parts
if (decimalPointIndex !== -1) {
wholePart = numStr.slice(0, decimalPointIndex);
decimalPart = numStr.slice(decimalPointIndex);
// If there is a decimal point, split the number into whole and decimal parts
if (decimalPointIndex !== -1) {
wholePart = numStr.slice(0, decimalPointIndex);
decimalPart = numStr.slice(decimalPointIndex);
}
// Take the first two and last two digits of the whole number part
let shortenedWholePart = wholePart.slice(0, 2) + '..' + wholePart.slice(-2);
// Return the formatted number
return shortenedWholePart + decimalPart;
}
// Take the first two and last two digits of the whole number part
let shortenedWholePart = wholePart.slice(0, 2) + '..' + wholePart.slice(-2);
// Return the formatted number
return shortenedWholePart + decimalPart;
}
// If the number is less than 100,000, return it unmodified
return num;
// If the number is less than 100,000, return it unmodified
return num;
}
export function getQuotaPerUnit() {
let quotaPerUnit = localStorage.getItem('quota_per_unit');
quotaPerUnit = parseFloat(quotaPerUnit);
return quotaPerUnit;
let quotaPerUnit = localStorage.getItem('quota_per_unit');
quotaPerUnit = parseFloat(quotaPerUnit);
return quotaPerUnit;
}
export function getQuotaWithUnit(quota, digits = 6) {
let quotaPerUnit = localStorage.getItem('quota_per_unit');
quotaPerUnit = parseFloat(quotaPerUnit);
return (quota / quotaPerUnit).toFixed(digits);
}
export function renderQuota(quota, digits = 2) {
let quotaPerUnit = localStorage.getItem('quota_per_unit');
let displayInCurrency = localStorage.getItem('display_in_currency');
quotaPerUnit = parseFloat(quotaPerUnit);
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return '$' + (quota / quotaPerUnit).toFixed(digits);
}
return renderNumber(quota);
let quotaPerUnit = localStorage.getItem('quota_per_unit');
let displayInCurrency = localStorage.getItem('display_in_currency');
quotaPerUnit = parseFloat(quotaPerUnit);
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return '$' + (quota / quotaPerUnit).toFixed(digits);
}
return renderNumber(quota);
}
export function renderQuotaWithPrompt(quota, digits) {
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return `(等价金额:${renderQuota(quota, digits)}`;
}
return '';
let displayInCurrency = localStorage.getItem('display_in_currency');
displayInCurrency = displayInCurrency === 'true';
if (displayInCurrency) {
return `(等价金额:${renderQuota(quota, digits)}`;
}
return '';
}
const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
'light-blue', 'lime', 'orange', 'pink',
'purple', 'red', 'teal', 'violet', 'yellow'
'light-blue', 'lime', 'orange', 'pink',
'purple', 'red', 'teal', 'violet', 'yellow'
]
export function stringToColor(str) {
let sum = 0;
// 对字符串中的每个字符进行操作
for (let i = 0; i < str.length; i++) {
// 将字符的ASCII值加到sum中
sum += str.charCodeAt(i);
}
// 使用模运算得到个位数
let i = sum % colors.length;
return colors[i];
let sum = 0;
// 对字符串中的每个字符进行操作
for (let i = 0; i < str.length; i++) {
// 将字符的ASCII值加到sum中
sum += str.charCodeAt(i);
}
// 使用模运算得到个位数
let i = sum % colors.length;
return colors[i];
}

View File

@@ -171,6 +171,32 @@ export function timestamp2string(timestamp) {
);
}
export function timestamp2string1(timestamp) {
let date = new Date(timestamp * 1000);
// let year = date.getFullYear().toString();
let month = (date.getMonth() + 1).toString();
let day = date.getDate().toString();
let hour = date.getHours().toString();
if (month.length === 1) {
month = '0' + month;
}
if (day.length === 1) {
day = '0' + day;
}
if (hour.length === 1) {
hour = '0' + hour;
}
return (
// year +
// '-' +
month +
'-' +
day +
' ' +
hour + ":00"
);
}
export function downloadTextAsFile(text, filename) {
let blob = new Blob([text], { type: 'text/plain;charset=utf-8' });
let url = URL.createObjectURL(blob);

View File

@@ -1,7 +1,8 @@
import { initVChartSemiTheme } from '@visactor/vchart-semi-theme';
import VChart from "@visactor/vchart";
import React from 'react';
import ReactDOM from 'react-dom/client';
import {BrowserRouter} from 'react-router-dom';
import {Container} from 'semantic-ui-react';
import App from './App';
import HeaderBar from './components/HeaderBar';
import Footer from './components/Footer';
@@ -14,6 +15,11 @@ import {StatusProvider} from './context/Status';
import {Layout} from "@douyinfe/semi-ui";
import SiderBar from "./components/SiderBar";
// initialization
initVChartSemiTheme({
isWatchingThemeSwitch: true,
});
const root = ReactDOM.createRoot(document.getElementById('root'));
const {Sider, Content, Header} = Layout;
root.render(

View File

@@ -1,5 +1,4 @@
import React, { useEffect, useState } from 'react';
import { Header, Segment } from 'semantic-ui-react';
import { API, showError } from '../../helpers';
import { marked } from 'marked';
import {Layout} from "@douyinfe/semi-ui";

View File

@@ -86,6 +86,9 @@ const EditChannel = (props) => {
case 23:
localModels = ['hunyuan'];
break;
case 24:
localModels = ['gemini-pro'];
break;
}
setInputs((inputs) => ({...inputs, models: localModels}));
}

View File

@@ -0,0 +1,292 @@
import React, {useEffect, useRef, useState} from 'react';
import {Button, Col, Form, Layout, Row, Spin} from "@douyinfe/semi-ui";
import VChart from '@visactor/vchart';
import {useEffectOnce} from "usehooks-ts";
import {API, isAdmin, showError, timestamp2string, timestamp2string1} from "../../helpers";
import {getQuotaWithUnit, renderNumber, renderQuotaNumberWithDigit} from "../../helpers/render";
const Detail = (props) => {
let now = new Date();
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
});
const {username, token_name, model_name, start_timestamp, end_timestamp, channel} = inputs;
const isAdminUser = isAdmin();
const initialized = useRef(false)
const [modelDataChart, setModelDataChart] = useState(null);
const [modelDataPieChart, setModelDataPieChart] = useState(null);
const [loading, setLoading] = useState(false);
const [quotaData, setQuotaData] = useState([]);
const [quotaDataPie, setQuotaDataPie] = useState([]);
const [quotaDataLine, setQuotaDataLine] = useState([]);
const handleInputChange = (value, name) => {
setInputs((inputs) => ({...inputs, [name]: value}));
};
const spec_line = {
type: 'bar',
data: [
{
id: 'barData',
values: [
]
}
],
xField: 'Time',
yField: 'Usage',
seriesField: 'Model',
stack: true,
legends: {
visible: true
},
title: {
visible: true,
text: '模型消耗分布(小时)'
},
bar: {
// The state style of bar
state: {
hover: {
stroke: '#000',
lineWidth: 1
}
}
},
tooltip: {
mark: {
content: [
{
key: datum => datum['Model'],
value: datum => renderQuotaNumberWithDigit(datum['Usage'], 4)
}
]
},
dimension: {
content: [
{
key: datum => datum['Model'],
value: datum => datum['Usage']
}
],
updateContent: array => {
// sort by value
array.sort((a, b) => b.value - a.value);
// add $
for (let i = 0; i < array.length; i++) {
array[i].value = renderQuotaNumberWithDigit(array[i].value, 4);
}
return array;
}
}
}
};
const spec_pie = {
type: 'pie',
data: [
{
id: 'id0',
values: [
{ type: 'null', value: '0' },
]
}
],
outerRadius: 0.8,
innerRadius: 0.5,
padAngle: 0.6,
valueField: 'value',
categoryField: 'type',
pie: {
style: {
cornerRadius: 10
},
state: {
hover: {
outerRadius: 0.85,
stroke: '#000',
lineWidth: 1
},
selected: {
outerRadius: 0.85,
stroke: '#000',
lineWidth: 1
}
}
},
title: {
visible: true,
text: '模型调用次数占比'
},
legends: {
visible: true,
orient: 'left'
},
label: {
visible: true
},
tooltip: {
mark: {
content: [
{
key: datum => datum['type'],
value: datum => renderNumber(datum['value'])
}
]
}
}
};
const loadQuotaData = async (lineChart, pieChart) => {
setLoading(true);
let url = '';
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/data/?username=${username}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
} else {
url = `/api/data/self/?start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const {success, message, data} = res.data;
if (success) {
setQuotaData(data);
if (data.length === 0) {
data.push({
'count': 0,
'model_name': '无数据',
'quota': 0,
'created_at': now.getTime() / 1000
})
}
updateChart(lineChart, pieChart, data);
} else {
showError(message);
}
setLoading(false);
};
const refresh = async () => {
await loadQuotaData(modelDataChart, modelDataPieChart);
};
const initChart = async () => {
let lineChart = modelDataChart
if (!modelDataChart) {
lineChart = new VChart(spec_line, {dom: 'model_data'});
setModelDataChart(lineChart);
lineChart.renderAsync();
}
let pieChart = modelDataPieChart
if (!modelDataPieChart) {
pieChart = new VChart(spec_pie, {dom: 'model_pie'});
setModelDataPieChart(pieChart);
pieChart.renderAsync();
}
console.log('init vchart');
await loadQuotaData(lineChart, pieChart)
}
const updateChart = (lineChart, pieChart, data) => {
if (isAdminUser) {
// 将所有用户合并
}
let pieData = [];
let lineData = [];
for (let i = 0; i < data.length; i++) {
const item = data[i];
// 合并model_name
let pieItem = pieData.find(it => it.type === item.model_name);
if (pieItem) {
pieItem.value += item.count;
} else {
pieData.push({
"type": item.model_name,
"value": item.count
});
}
// 合并created_at和model_name 为 lineData, created_at 数据类型是小时的时间戳
// 转换日期格式
let createTime = timestamp2string1(item.created_at);
let lineItem = lineData.find(it => it.Time === createTime && it.Model === item.model_name);
if (lineItem) {
lineItem.Usage += parseFloat(getQuotaWithUnit(item.quota));
} else {
lineData.push({
"Time": createTime,
"Model": item.model_name,
"Usage": parseFloat(getQuotaWithUnit(item.quota))
});
}
}
// sort by count
pieData.sort((a, b) => b.value - a.value);
pieChart.updateData('id0', pieData);
lineChart.updateData('barData', lineData);
pieChart.reLayout();
lineChart.reLayout();
}
useEffect(() => {
if (!initialized.current) {
initialized.current = true;
initChart();
}
}, []);
return (
<>
<Layout>
<Layout.Header>
<h3>数据看板</h3>
</Layout.Header>
<Layout.Content>
<Form layout='horizontal' style={{marginTop: 10}}>
<>
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
initValue={start_timestamp}
value={start_timestamp} type='dateTime'
name='start_timestamp'
onChange={value => handleInputChange(value, 'start_timestamp')}/>
<Form.DatePicker field="end_timestamp" fluid label='结束时间' style={{width: 272}}
initValue={end_timestamp}
value={end_timestamp} type='dateTime'
name='end_timestamp'
onChange={value => handleInputChange(value, 'end_timestamp')}/>
{
isAdminUser && <>
<Form.Input field="username" label='用户名称' style={{width: 176}} value={username}
placeholder={'可选值'} name='username'
onChange={value => handleInputChange(value, 'username')}/>
</>
}
<Form.Section>
<Button label='查询' type="primary" htmlType="submit" className="btn-margin-right"
onClick={refresh} loading={loading}>查询</Button>
</Form.Section>
</>
</Form>
<Spin spinning={loading}>
<div style={{height: 500}}>
<div id="model_pie" style={{width: '100%', minWidth: 100}}></div>
</div>
<div style={{height: 500}}>
<div id="model_data" style={{width: '100%', minWidth: 100}}></div>
</div>
</Spin>
</Layout.Content>
</Layout>
</>
);
};
export default Detail;

View File

@@ -2,22 +2,37 @@ import React, {useEffect, useRef, useState} from 'react';
import {useParams, useNavigate} from 'react-router-dom';
import {API, isMobile, showError, showSuccess, timestamp2string} from '../../helpers';
import {renderQuota, renderQuotaWithPrompt} from '../../helpers/render';
import {Layout, SideSheet, Button, Space, Spin, Banner, Input, DatePicker, AutoComplete, Typography} from "@douyinfe/semi-ui";
import {
Layout,
SideSheet,
Button,
Space,
Spin,
Banner,
Input,
DatePicker,
AutoComplete,
Typography,
Checkbox, Select
} from "@douyinfe/semi-ui";
import Title from "@douyinfe/semi-ui/lib/es/typography/title";
import {Divider} from "semantic-ui-react";
const EditToken = (props) => {
const isEdit = props.editingToken.id !== undefined;
const [isEdit, setIsEdit] = useState(false);
const [loading, setLoading] = useState(isEdit);
const originInputs = {
name: '',
remain_quota: isEdit ? 0 : 500000,
expired_time: -1,
unlimited_quota: false
unlimited_quota: false,
model_limits_enabled: false,
model_limits: [],
};
const [inputs, setInputs] = useState(originInputs);
const {name, remain_quota, expired_time, unlimited_quota} = inputs;
const {name, remain_quota, expired_time, unlimited_quota, model_limits_enabled, model_limits} = inputs;
// const [visible, setVisible] = useState(false);
const [models, setModels] = useState({});
const navigate = useNavigate();
const handleInputChange = (name, value) => {
setInputs((inputs) => ({...inputs, [name]: value}));
@@ -44,6 +59,20 @@ const EditToken = (props) => {
setInputs({...inputs, unlimited_quota: !unlimited_quota});
};
const loadModels = async () => {
let res = await API.get(`/api/user/models`);
const {success, message, data} = res.data;
if (success) {
let localModelOptions = data.map((model) => ({
label: model,
value: model
}));
setModels(localModelOptions);
} else {
showError(message);
}
}
const loadToken = async () => {
setLoading(true);
let res = await API.get(`/api/token/${props.editingToken.id}`);
@@ -52,6 +81,11 @@ const EditToken = (props) => {
if (data.expired_time !== -1) {
data.expired_time = timestamp2string(data.expired_time);
}
if (data.model_limits !== '') {
data.model_limits = data.model_limits.split(',');
} else {
data.model_limits = [];
}
setInputs(data);
} else {
showError(message);
@@ -59,17 +93,22 @@ const EditToken = (props) => {
setLoading(false);
};
useEffect(() => {
if (isEdit) {
loadToken().then(
() => {
// console.log(inputs);
}
);
} else {
setInputs(originInputs);
}
setIsEdit(props.editingToken.id !== undefined);
}, [props.editingToken.id]);
useEffect(() => {
if (!isEdit) {
setInputs(originInputs);
} else {
loadToken().then(
() => {
// console.log(inputs);
}
);
}
loadModels();
}, [isEdit]);
// 新增 state 变量 tokenCount 来记录用户想要创建的令牌数量,默认为 1
const [tokenCount, setTokenCount] = useState(1);
@@ -107,7 +146,7 @@ const EditToken = (props) => {
}
localInputs.expired_time = Math.ceil(time / 1000);
}
localInputs.model_limits = localInputs.model_limits.join(',');
let res = await API.put(`/api/token/`, {...localInputs, id: parseInt(props.editingToken.id)});
const {success, message} = res.data;
if (success) {
@@ -137,7 +176,7 @@ const EditToken = (props) => {
}
localInputs.expired_time = Math.ceil(time / 1000);
}
localInputs.model_limits = localInputs.model_limits.join(',');
let res = await API.post(`/api/token/`, localInputs);
const {success, message} = res.data;
@@ -234,7 +273,7 @@ const EditToken = (props) => {
value={remain_quota}
autoComplete='new-password'
type='number'
position={'top'}
// position={'top'}
data={[
{value: 500000, label: '1$'},
{value: 5000000, label: '10$'},
@@ -245,27 +284,30 @@ const EditToken = (props) => {
]}
disabled={unlimited_quota}
/>
<div style={{marginTop: 20}}>
<Typography.Text>新建数量</Typography.Text>
</div>
{!isEdit && (
<AutoComplete
style={{ marginTop: 8 }}
label='数量'
placeholder={'请选择或输入创建令牌的数量'}
onChange={(value) => handleTokenCountChange(value)}
onSelect={(value) => handleTokenCountChange(value)}
value={tokenCount.toString()}
autoComplete='off'
type='number'
data={[
{ value: 10, label: '10个' },
{ value: 20, label: '20个' },
{ value: 30, label: '30个' },
{ value: 100, label: '100个' },
]}
disabled={unlimited_quota}
/>
<>
<div style={{marginTop: 20}}>
<Typography.Text>新建数量</Typography.Text>
</div>
<AutoComplete
style={{ marginTop: 8 }}
label='数量'
placeholder={'请选择或输入创建令牌的数量'}
onChange={(value) => handleTokenCountChange(value)}
onSelect={(value) => handleTokenCountChange(value)}
value={tokenCount.toString()}
autoComplete='off'
type='number'
data={[
{ value: 10, label: '10个' },
{ value: 20, label: '20个' },
{ value: 30, label: '30个' },
{ value: 100, label: '100个' },
]}
disabled={unlimited_quota}
/>
</>
)}
<div>
@@ -273,6 +315,34 @@ const EditToken = (props) => {
setUnlimitedQuota();
}}>{unlimited_quota ? '取消无限额度' : '设为无限额度'}</Button>
</div>
<Divider/>
<div style={{marginTop: 10, display: 'flex'}}>
<Space>
<Checkbox
name='model_limits_enabled'
checked={model_limits_enabled}
onChange={(e) => handleInputChange('model_limits_enabled', e.target.checked)}
>
</Checkbox>
<Typography.Text>启用模型限制非必要不建议启用</Typography.Text>
</Space>
</div>
<Select
style={{marginTop: 8}}
placeholder={'请选择该渠道所支持的模型'}
name='models'
required
multiple
selection
onChange={value => {
handleInputChange('model_limits', value);
}}
value={inputs.model_limits}
autoComplete='new-password'
optionList={models}
disabled={!model_limits_enabled}
/>
</Spin>
</SideSheet>
</>

View File

@@ -74,12 +74,17 @@ const TopUp = () => {
const {message, data} = res.data;
// showInfo(message);
if (message === 'success') {
let params = data
let url = res.data.url
let form = document.createElement('form')
form.action = url
form.method = 'POST'
form.target = '_blank'
// 判断是否为safari浏览器
let isSafari = navigator.userAgent.indexOf("Safari") > -1 && navigator.userAgent.indexOf("Chrome") < 1;
if (!isSafari) {
form.target = '_blank'
}
for (let key in params) {
let input = document.createElement('input')
input.type = 'hidden'
@@ -126,7 +131,7 @@ const TopUp = () => {
}, []);
const renderAmount = () => {
console.log(amount);
// console.log(amount);
return amount + '元';
}