Compare commits

...

35 Commits

Author SHA1 Message Date
Xyfacai
d3e070d963 fix: mj 错误处理 2024-01-14 16:17:45 +08:00
CaIon
e688e41360 perf: 美化绘画界面UI 2024-01-13 22:38:50 +08:00
CaIon
e41fcd563e chore: 数据看板限制用户只能查询跨度一个月的数据 2024-01-13 01:32:23 +08:00
CaIon
a8715c61c8 fix: 修复数据看板渲染错误 2024-01-13 01:27:50 +08:00
CaIon
00306aa142 perf: 数据看板支持选择时间粒度 2024-01-13 00:33:52 +08:00
CaIon
d30b9321b2 chore: UpdateMidjourneyTaskBulk with SafeGoroutine 2024-01-12 18:38:45 +08:00
CaIon
2b3539fcaa fix: 修复渠道一致性问题 2024-01-12 14:57:03 +08:00
CaIon
e2a1caba4c fix: 修复渠道一致性问题 2024-01-12 14:36:15 +08:00
CaIon
febcadb42c chore: add SafeGoroutine 2024-01-12 13:49:53 +08:00
CaIon
075b1ac113 fix: delete RelayPanicRecover 2024-01-12 13:48:13 +08:00
CaIon
6757166de5 Merge remote-tracking branch 'origin/main' 2024-01-12 13:46:08 +08:00
CaIon
2a9c3ac6af fix: 修复mj错误返还费用问题 2024-01-12 13:45:52 +08:00
Calcium-Ion
2ccd6c04e0 Merge pull request #45 from AI-ASS/patch-1
fix: the Redis problem in the CacheGetUsername function
2024-01-12 13:17:20 +08:00
GAI Group
6aadcf0c78 fix: the Redis problem in the CacheGetUsername function
fix: the Redis problem in the CacheGetUsername function
2024-01-12 13:15:43 +08:00
CaIon
312417f393 update README.md 2024-01-12 00:35:22 +08:00
CaIon
ad85792818 fix: user group size 2024-01-11 22:06:30 +08:00
CaIon
521ef5e219 fix: token model limit 2024-01-11 19:18:48 +08:00
CaIon
f07b9f8ab2 chore: cache username 2024-01-11 18:40:44 +08:00
CaIon
64e9e9cc20 fix: 完善令牌预扣费逻辑 2024-01-11 14:12:48 +08:00
CaIon
701a28d0da fix: fix relay openai panic 2024-01-10 19:26:11 +08:00
CaIon
d04d2a6c4d perf: UI美化 2024-01-10 18:32:44 +08:00
CaIon
e2317524f9 perf: 美化数据看板 2024-01-10 17:49:55 +08:00
CaIon
a3b726dd82 fix: 修复高并发下,高额度用户使用低额度令牌没有预扣费的问题 2024-01-10 14:23:23 +08:00
CaIon
042d55cfd3 fix: fix response choice json 2024-01-10 13:57:49 +08:00
CaIon
3432d9e0f6 fix: do not consume user quota if failed 2024-01-10 13:56:29 +08:00
CaIon
aca8d25372 feat: able to change gemini safety setting 2024-01-10 13:56:10 +08:00
CaIon
6a24e8953f feat: able to fix channels 2024-01-10 13:23:43 +08:00
CaIon
3e13810ca2 fix: fix channel panic 2024-01-10 00:17:03 +08:00
CaIon
4a4df75830 docs: update Midjourney.md 2024-01-09 20:07:38 +08:00
CaIon
db29c96361 fix: add Tools support 2024-01-09 16:44:34 +08:00
CaIon
1618a8c9fc perf: 优化数据看板性能 2024-01-09 16:20:04 +08:00
CaIon
75b6327f4f feat: support Azure dall-e 2024-01-09 15:46:45 +08:00
CaIon
fb95216b5a chore: delete model price log 2024-01-09 12:11:09 +08:00
CaIon
6ed365c267 fix: 修复数据看板筛选用户的时候不能指定时间的问题 2024-01-09 12:11:09 +08:00
CaIon
4f95b7d049 update README.md 2024-01-09 12:11:08 +08:00
40 changed files with 961 additions and 521 deletions

View File

@@ -1,254 +1,291 @@
# Midjourney Proxy API文档
**简介**:Midjourney Proxy API文档
## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
**HOST**:https://api.nekoedu.com
```json
{
"gpt-4-gizmo-*": 0.1,
"mj_imagine": 0.1,
"mj_variation": 0.1,
"mj_reroll": 0.1,
"mj_blend": 0.1,
"mj_describe": 0.05,
"mj_upscale": 0.05
}
```
## 渠道设置
**Version**:v2.3.5
### 对接 midjourney-proxy
1. 部署Midjourney-Proxy并配置好midjourney账号等强烈建议设置密钥[项目地址](https://github.com/novicezk/midjourney-proxy)
2. 在渠道管理中添加渠道渠道类型选择Midjourney Proxy模型选择midjourney
3. 地址填写midjourney-proxy部署的地址例如http://localhost:8080
4. 密钥填写midjourney-proxy的密钥如果没有设置密钥可以随便填
### 对接上游new api
1. 在渠道管理中添加渠道渠道类型选择Midjourney Proxy模型选择midjourney
2. 地址填写上游new api的地址例如http://localhost:8080
3. 密钥填写上游new api的密钥
[TOC]
# 任务提交
## 绘图变化
## 任务提交
### 绘图变化
**接口地址**:`/mj/submit/change`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"action": "UPSCALE",
"index": 1,
"notifyHook": "",
"state": "",
"taskId": "1320098173412546"
"action"
:
"UPSCALE",
"index"
:
1,
"notifyHook"
:
"",
"state"
:
"",
"taskId"
:
"1320098173412546"
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|changeDTO|changeDTO|body|true|变化任务提交参数|变化任务提交参数|
|  action|UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL||true|string||
|  index|序号(1~4), action为UPSCALE,VARIATION时必传||false|integer(int32)||
|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
|  state|自定义参数||false|string||
|  taskId|任务ID||true|string||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------------------------|-----------------------------------------------------------------------|------|-------|----------------|----------|
| changeDTO | changeDTO | body | true | 变化任务提交参数 | 变化任务提交参数 |
|   action | UPSCALE(放大); VARIATION(变换); REROLL(重新生成),可用值:UPSCALE,VARIATION,REROLL | | true | string | |
|   index | 序号(1~4), action为UPSCALE,VARIATION时必传 | | false | integer(int32) | |
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|   state | 自定义参数 | | false | string | |
|   taskId | 任务ID | | true | string | |
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|提交结果|
|201|Created||
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 提交结果 |
| 201 | Created | |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
|description|描述|string||
|properties|扩展字段|object||
|result|任务ID|string||
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|-------------------------------------------|----------------|----------------|
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
| description | 描述 | string | |
| properties | 扩展字段 | object | |
| result | 任务ID | string | |
**响应示例**:
```javascript
{
"code": 1,
"description": "提交成功",
"properties": {},
"result": 1320098173412546
"code"
:
1,
"description"
:
"提交成功",
"properties"
:
{
}
,
"result"
:
1320098173412546
}
```
## 提交Imagine任务
### 提交Imagine任务
**接口地址**:`/mj/submit/imagine`
**请求方式**:`POST`
**请求数据类型**:`application/json`
**响应数据类型**:`*/*`
**接口描述**:
**请求示例**:
```javascript
{
"base64": "",
"notifyHook": "",
"prompt": "Cat",
"state": ""
"base64"
:
"",
"notifyHook"
:
"",
"prompt"
:
"Cat",
"state"
:
""
}
```
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|imagineDTO|imagineDTO|body|true|Imagine提交参数|Imagine提交参数|
|  base64|垫图base64||false|string||
|  notifyHook|回调地址, 为空时使用全局notifyHook||false|string||
|  prompt|提示词||true|string||
|  state|自定义参数||false|string||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------------------------|-------------------------|------|-------|-------------|-------------|
| imagineDTO | imagineDTO | body | true | Imagine提交参数 | Imagine提交参数 |
|   base64 | 垫图base64 | | false | string | |
|   notifyHook | 回调地址, 为空时使用全局notifyHook | | false | string | |
|   prompt | 提示词 | | true | string | |
|   state | 自定义参数 | | false | string | |
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|提交结果|
|201|Created||
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 提交结果 |
| 201 | Created | |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|code|状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)|integer(int32)|integer(int32)|
|description|描述|string||
|properties|扩展字段|object||
|result|任务ID|string||
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|-------------------------------------------|----------------|----------------|
| code | 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误) | integer(int32) | integer(int32) |
| description | 描述 | string | |
| properties | 扩展字段 | object | |
| result | 任务ID | string | |
**响应示例**:
```javascript
{
"code": 1,
"description": "提交成功",
"properties": {},
"result": 1320098173412546
"code"
:
1,
"description"
:
"提交成功",
"properties"
:
{
}
,
"result"
:
1320098173412546
}
```
## 任务查询
# 任务查询
## 指定ID获取任务
### 指定ID获取任务
**接口地址**:`/mj/task/{id}/fetch`
**请求方式**:`GET`
**请求数据类型**:`application/x-www-form-urlencoded`
**响应数据类型**:`*/*`
**接口描述**:
**请求参数**:
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
| -------- | -------- | ----- | -------- | -------- | ------ |
|id|任务ID|path|false|string||
| 参数名称 | 参数说明 | 请求类型 | 是否必须 | 数据类型 | schema |
|------|------|------|-------|--------|--------|
| id | 任务ID | path | false | string | |
**响应状态**:
| 状态码 | 说明 | schema |
| -------- | -------- | ----- |
|200|OK|任务|
|401|Unauthorized||
|403|Forbidden||
|404|Not Found||
| 状态码 | 说明 | schema |
|-----|--------------|--------|
| 200 | OK | 任务 |
| 401 | Unauthorized | |
| 403 | Forbidden | |
| 404 | Not Found | |
**响应参数**:
| 参数名称 | 参数说明 | 类型 | schema |
| -------- | -------- | ----- |----- |
|action|可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND|string||
|description|任务描述|string||
|failReason|失败原因|string||
|finishTime|结束时间|integer(int64)|integer(int64)|
|id|任务ID|string||
|imageUrl|图片url|string||
|progress|任务进度|string||
|prompt|提示词|string||
|promptEn|提示词-英文|string||
|startTime|开始执行时间|integer(int64)|integer(int64)|
|state|自定义参数|string||
|status|任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS|string||
|submitTime|提交时间|integer(int64)|integer(int64)|
| 参数名称 | 参数说明 | 类型 | schema |
|-------------|----------------------------------------------------------|----------------|----------------|
| action | 可用值:IMAGINE,UPSCALE,VARIATION,REROLL,DESCRIBE,BLEND | string | |
| description | 任务描述 | string | |
| failReason | 失败原因 | string | |
| finishTime | 结束时间 | integer(int64) | integer(int64) |
| id | 任务ID | string | |
| imageUrl | 图片url | string | |
| progress | 任务进度 | string | |
| prompt | 提示词 | string | |
| promptEn | 提示词-英文 | string | |
| startTime | 开始执行时间 | integer(int64) | integer(int64) |
| state | 自定义参数 | string | |
| status | 任务状态,可用值:NOT_START,SUBMITTED,IN_PROGRESS,FAILURE,SUCCESS | string | |
| submitTime | 提交时间 | integer(int64) | integer(int64) |
**响应示例**:
```javascript
{
"action": "",
"description": "",
"failReason": "",
"finishTime": 0,
"id": "",
"imageUrl": "",
"progress": "",
"prompt": "",
"promptEn": "",
"startTime": 0,
"state": "",
"status": "",
"submitTime": 0
"action"
:
"",
"description"
:
"",
"failReason"
:
"",
"finishTime"
:
0,
"id"
:
"",
"imageUrl"
:
"",
"progress"
:
"",
"prompt"
:
"",
"promptEn"
:
"",
"startTime"
:
0,
"state"
:
"",
"status"
:
"",
"submitTime"
:
0
}
```

View File

@@ -23,6 +23,7 @@
+ [x] /mj/submit/describe
+ [x] /mj/image/{id} (通过此接口获取图片,**请必须在系统设置中填写服务器地址!!**
+ [x] /mj/task/{id}/fetch 此接口返回的图片地址为经过One API转发的地址
+ [x] /task/list-by-condition
3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口:
+ [x] 易支付
4. 支持用key查询使用额度:
@@ -47,6 +48,18 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
### 使用宝塔面板Docker功能部署
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
# 注意数据库要开启远程访问并且只允许服务器IP访问
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest
# 注意数据库要开启远程访问并且只允许服务器IP访问
```
## Midjourney接口设置文档
[对接文档](Midjourney.md)
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">

View File

@@ -26,7 +26,8 @@ var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var DrawingEnabled = true
var DataExportEnabled = true
var DataExportInterval = 5 // unit: minute
var DataExportInterval = 5 // unit: minute
var DataExportDefaultTime = "hour" // unit: minute
// Any options with "Secret", "Token" in its key won't be return by GetOptions
@@ -104,6 +105,8 @@ var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
const (
RequestIdKey = "X-Oneapi-Request-Id"
)

32
common/go-channel.go Normal file
View File

@@ -0,0 +1,32 @@
package common
import (
"fmt"
"runtime/debug"
)
func SafeGoroutine(f func()) {
go func() {
defer func() {
if r := recover(); r != nil {
SysError(fmt.Sprintf("child goroutine panic occured: error: %v, stack: %s", r, string(debug.Stack())))
}
}()
f()
}()
}
func SafeSend(ch chan bool, value bool) (closed bool) {
defer func() {
// Recover from panic if one occured. A panic would mean the channel was closed.
if recover() != nil {
closed = true
}
}()
// This will panic if the channel is closed.
ch <- value
// If the code reaches here, then the channel was not closed.
return false
}

View File

@@ -109,7 +109,7 @@ func GetModelPrice(name string) float64 {
}
price, ok := ModelPrice[name]
if !ok {
SysError("model price not found: " + name)
//SysError("model price not found: " + name)
return -1
}
return price

View File

@@ -202,6 +202,13 @@ func GetOrDefault(env string, defaultValue int) int {
return num
}
func GetOrDefaultString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}

View File

@@ -35,6 +35,22 @@ func GetAllChannels(c *gin.Context) {
return
}
func FixChannelsAbilities(c *gin.Context) {
count, err := model.FixAbility()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
}
func SearchChannels(c *gin.Context) {
keyword := c.Query("keyword")
group := c.Query("group")

View File

@@ -148,13 +148,7 @@ import (
*/
func UpdateMidjourneyTaskBulk() {
//revocer
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
imageModel := "midjourney"
//imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)
@@ -167,13 +161,27 @@ func UpdateMidjourneyTaskBulk() {
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
nullTaskIds := make([]int, 0)
for _, task := range tasks {
if task.MjId == "" {
// 统计失败的未完成任务
nullTaskIds = append(nullTaskIds, task.Id)
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(nullTaskIds) > 0 {
err := model.MjBulkUpdateByTaskIds(nullTaskIds, map[string]any{
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogError(ctx, fmt.Sprintf("Fix null mj_id task error: %v", err))
} else {
common.LogInfo(ctx, fmt.Sprintf("Fix null mj_id task success: %v", nullTaskIds))
}
}
if len(taskChannelM) == 0 {
continue
}
@@ -226,7 +234,7 @@ func UpdateMidjourneyTaskBulk() {
var responseItems []Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
continue
}
resp.Body.Close()
@@ -256,10 +264,7 @@ func UpdateMidjourneyTaskBulk() {
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
quota := task.Quota
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {

View File

@@ -16,26 +16,27 @@ func GetStatus(c *gin.Context) {
"success": true,
"message": "",
"data": gin.H{
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"price": common.Price,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_data_export": common.DataExportEnabled,
"start_time": common.StartTime,
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": common.ServerAddress,
"price": common.Price,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"enable_batch_update": common.BatchUpdateEnabled,
"enable_drawing": common.DrawingEnabled,
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
},
})
return

View File

@@ -106,13 +106,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiVersion := GetAPIVersion(c)
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion)
}
requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
req.Header.Set("api-key", apiKey)
req.ContentLength = c.Request.ContentLength
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))

View File

@@ -60,24 +60,24 @@ type GeminiChatGenerationConfig struct {
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
//SafetySettings: []GeminiChatSafetySettings{
// {
// Category: "HARM_CATEGORY_HARASSMENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_HATE_SPEECH",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
//},
SafetySettings: []GeminiChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: common.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: common.GeminiSafetySetting,
},
},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,

View File

@@ -31,7 +31,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
}
if imageRequest.Model == "" {
imageRequest.Model = "dall-e"
imageRequest.Model = "dall-e-2"
}
if imageRequest.Size == "" {
imageRequest.Size = "1024x1024"
@@ -86,8 +86,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL = c.GetString("base_url")
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := GetAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion)
}
var requestBody io.Reader
if isModelMapped {
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -132,8 +138,14 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
token := c.Request.Header.Get("Authorization")
if channelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -158,6 +170,9 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse
defer func(ctx context.Context) {
if consumeQuota {
if resp.StatusCode != http.StatusOK {
return
}
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())

View File

@@ -78,6 +78,7 @@ func RelayMidjourneyImage(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{
"error": "http_get_image_failed",
})
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
@@ -544,6 +545,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
Progress: "0%",
FailReason: "",
ChannelId: c.GetInt("channel_id"),
Quota: quota,
}
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {

View File

@@ -83,7 +83,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
// wait data out
time.Sleep(2 * time.Second)
}
stopChan <- true
common.SafeSend(stopChan, true)
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {

View File

@@ -52,6 +52,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
var textRequest GeneralOpenAIRequest
@@ -261,10 +262,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
// 用户额度充足,判断令牌额度是否充足
if !tokenUnlimited {
// 非无限令牌,判断令牌额度是否充足
tokenQuota := c.GetInt("token_quota")
if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", userId, userQuota, tokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
}
if preConsumedQuota > 0 {
userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota)

View File

@@ -301,3 +301,12 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
}
return fullRequestURL
}
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}

View File

@@ -99,24 +99,37 @@ const (
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
RelayModeMidjourneyTaskFetchByCondition
RelayModeAudio
RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
)
// https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {
@@ -235,7 +248,7 @@ type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason *string `json:"finish_reason"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
@@ -291,14 +304,22 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
relayMode = RelayModeAudio
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
}
var err *OpenAIErrorWithStatusCode
switch relayMode {
case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode)
case RelayModeAudio:
case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode)
default:
err = relayTextHelper(c, relayMode)
@@ -311,7 +332,7 @@ func Relay(c *gin.Context) {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
} else {
if err.StatusCode == http.StatusTooManyRequests {
//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"

View File

@@ -124,14 +124,16 @@ func AddToken(c *gin.Context) {
return
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
UserId: c.GetInt("id"),
Name: token.Name,
Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
ModelLimitsEnabled: token.ModelLimitsEnabled,
ModelLimits: token.ModelLimits,
}
err = cleanToken.Insert()
if err != nil {

View File

@@ -31,6 +31,14 @@ func GetUserQuotaDates(c *gin.Context) {
userId := c.GetInt("id")
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
// 判断时间跨度是否超过 1 个月
if endTimestamp-startTimestamp > 2592000 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "时间跨度不能超过 1 个月",
})
return
}
dates, err := model.GetQuotaDataByUserId(userId, startTimestamp, endTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{

View File

@@ -85,7 +85,9 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
go controller.UpdateMidjourneyTaskBulk()
common.SafeGoroutine(func() {
controller.UpdateMidjourneyTaskBulk()
})
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

View File

@@ -115,6 +115,10 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
c.Set("token_unlimited_quota", token.UnlimitedQuota)
if !token.UnlimitedQuota {
c.Set("token_quota", token.RemainQuota)
}
if token.ModelLimitsEnabled {
c.Set("token_model_limit_enabled", true)
c.Set("token_model_limit", token.GetModelLimitsMap())

View File

@@ -101,13 +101,19 @@ func Distribute() func(c *gin.Context) {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)

View File

@@ -1,6 +1,8 @@
package model
import (
"errors"
"fmt"
"one-api/common"
"strings"
)
@@ -68,7 +70,7 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
}
} else {
return nil, nil
return nil, errors.New("channel not found")
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
@@ -118,3 +120,46 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func FixAbility() (int, error) {
var channelIds []int
count := 0
// Find all channel ids from channel table
err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
return 0, err
}
// Delete abilities of channels that are not in channel table
err = DB.Where("channel_id NOT IN (?)", channelIds).Delete(&Ability{}).Error
if err != nil {
common.SysError(fmt.Sprintf("Delete abilities of channels that are not in channel table failed: %s", err.Error()))
return 0, err
}
common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
count += len(channelIds)
// Use channelIds to find channel not in abilities table
var abilityChannelIds []int
err = DB.Model(&Ability{}).Pluck("channel_id", &abilityChannelIds).Error
if err != nil {
common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
return 0, err
}
var channels []Channel
err = DB.Where("id NOT IN (?)", abilityChannelIds).Find(&channels).Error
if err != nil {
return 0, err
}
for _, channel := range channels {
err := channel.UpdateAbilities()
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {
common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
count++
}
}
InitChannelCache()
return count, nil
}

View File

@@ -68,6 +68,24 @@ func CacheGetUserGroup(id int) (group string, err error) {
return group, err
}
func CacheGetUsername(id int) (username string, err error) {
if !common.RedisEnabled {
return GetUsernameById(id)
}
username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
if err != nil {
username, err = GetUsernameById(id)
if err != nil {
return "", err
}
err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user group error: " + err.Error())
}
}
return username, err
}
func CacheGetUserQuota(id int) (quota int, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
@@ -240,8 +258,8 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
return channel, nil
}
}
// return the last channel if no channel is found
return channels[endIdx-1], nil
// return null if no channel is not found
return nil, errors.New("channel not found")
}
func CacheGetChannel(id int) (*Channel, error) {

View File

@@ -41,9 +41,10 @@ func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !common.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
Username: username,
CreatedAt: common.GetTimestamp(),
Type: logType,
Content: content,
@@ -59,7 +60,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
if !common.LogConsumeEnabled {
return
}
username := GetUsernameById(userId)
username, _ := CacheGetUsername(userId)
log := &Log{
UserId: userId,
Username: username,
@@ -79,7 +80,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
common.LogError(ctx, "failed to record log: "+err.Error())
}
if common.DataExportEnabled {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp())
common.SafeGoroutine(func() {
LogQuotaData(userId, username, modelName, quota, common.GetTimestamp())
})
}
}

View File

@@ -18,6 +18,7 @@ type Midjourney struct {
Progress string `json:"progress"`
FailReason string `json:"fail_reason"`
ChannelId int `json:"channel_id"`
Quota int `json:"quota"`
}
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
@@ -152,8 +153,14 @@ func (midjourney *Midjourney) Update() error {
return err
}
func MjBulkUpdate(taskIDs []string, params map[string]any) error {
func MjBulkUpdate(mjIds []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", taskIDs).
Where("mj_id in (?)", mjIds).
Updates(params).Error
}
func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("id in (?)", taskIDs).
Updates(params).Error
}

View File

@@ -79,6 +79,7 @@ func InitOptionMap() {
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -228,6 +229,8 @@ func updateOptionMap(key string, value string) (err error) {
common.RetryTimes, _ = strconv.Atoi(value)
case "DataExportInterval":
common.DataExportInterval, _ = strconv.Atoi(value)
case "DataExportDefaultTime":
common.DataExportDefaultTime = value
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":

View File

@@ -2,6 +2,7 @@ package model
import (
"fmt"
"gorm.io/gorm"
"one-api/common"
"sync"
"time"
@@ -37,9 +38,7 @@ func UpdateQuotaData() {
var CacheQuotaData = make(map[string]*QuotaData)
var CacheQuotaDataLock = sync.Mutex{}
func LogQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64) {
// 只精确到小时
createdAt = createdAt - (createdAt % 3600)
func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64) {
key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt)
quotaData, ok := CacheQuotaData[key]
if ok {
@@ -59,9 +58,12 @@ func LogQuotaDataCache(userId int, username string, modelName string, quota int,
}
func LogQuotaData(userId int, username string, modelName string, quota int, createdAt int64) {
// 只精确到小时
createdAt = createdAt - (createdAt % 3600)
CacheQuotaDataLock.Lock()
defer CacheQuotaDataLock.Unlock()
LogQuotaDataCache(userId, username, modelName, quota, createdAt)
logQuotaDataCache(userId, username, modelName, quota, createdAt)
}
func SaveQuotaDataCache() {
@@ -77,9 +79,10 @@ func SaveQuotaDataCache() {
DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.CreatedAt).First(quotaDataDB)
if quotaDataDB.Id > 0 {
quotaDataDB.Count += quotaData.Count
quotaDataDB.Quota += quotaData.Quota
DB.Table("quota_data").Save(quotaDataDB)
//quotaDataDB.Count += quotaData.Count
//quotaDataDB.Quota += quotaData.Quota
//DB.Table("quota_data").Save(quotaDataDB)
increaseQuotaData(quotaData.UserID, quotaData.Username, quotaData.ModelName, quotaData.Count, quotaData.Quota, quotaData.CreatedAt)
} else {
DB.Table("quota_data").Create(quotaData)
}
@@ -88,10 +91,21 @@ func SaveQuotaDataCache() {
common.SysLog(fmt.Sprintf("保存数据看板数据成功,共保存%d条数据", size))
}
func increaseQuotaData(userId int, username string, modelName string, count int, quota int, createdAt int64) {
err := DB.Table("quota_data").Where("user_id = ? and username = ? and model_name = ? and created_at = ?",
userId, username, modelName, createdAt).Updates(map[string]interface{}{
"count": gorm.Expr("count + ?", count),
"quota": gorm.Expr("quota + ?", quota),
}).Error
if err != nil {
common.SysLog(fmt.Sprintf("increaseQuotaData error: %s", err))
}
}
func GetQuotaDataByUsername(username string, startTime int64, endTime int64) (quotaData []*QuotaData, err error) {
var quotaDatas []*QuotaData
// 从quota_data表中查询数据
err = DB.Table("quota_data").Where("username = ?", username).Find(&quotaDatas).Error
err = DB.Table("quota_data").Where("username = ? and created_at >= ? and created_at <= ?", username, startTime, endTime).Find(&quotaDatas).Error
return quotaDatas, err
}

View File

@@ -26,7 +26,7 @@ type User struct {
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
@@ -452,7 +452,7 @@ func updateUserRequestCount(id int, count int) {
}
}
func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username
func GetUsernameById(id int) (username string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
return username, err
}

View File

@@ -84,6 +84,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())

View File

@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
modelsRouter.GET("/:model", controller.RetrieveModel)
}
relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
{
relayV1Router.POST("/completions", controller.Relay)
relayV1Router.POST("/chat/completions", controller.Relay)

View File

@@ -51,6 +51,7 @@ function App() {
localStorage.setItem('display_in_currency', data.display_in_currency);
localStorage.setItem('enable_drawing', data.enable_drawing);
localStorage.setItem('enable_data_export', data.enable_data_export);
localStorage.setItem('data_export_default_time', data.data_export_default_time);
if (data.chat_link) {
localStorage.setItem('chat_link', data.chat_link);
} else {

View File

@@ -1,6 +1,4 @@
import React, {useEffect, useState} from 'react';
import {Input, Label, Message, Popup} from 'semantic-ui-react';
import {Link} from 'react-router-dom';
import {
API,
isMobile,
@@ -389,23 +387,15 @@ const ChannelsTable = () => {
return <Tag size='large' color='green'>已启用</Tag>;
case 2:
return (
<Popup
trigger={<Tag size='large' color='red'>
已禁用
</Tag>}
content='本渠道被手动禁用'
basic
/>
<Tag size='large' color='yellow'>
已禁用
</Tag>
);
case 3:
return (
<Popup
trigger={<Tag size='large' color='yellow'>
已禁用
</Tag>}
content='本渠道被程序自动禁用'
basic
/>
<Tag size='large' color='yellow'>
自动禁用
</Tag>
);
default:
return (
@@ -531,6 +521,17 @@ const ChannelsTable = () => {
setLoading(false);
}
const fixChannelsAbilities = async () => {
const res = await API.post(`/api/channel/fix`);
const {success, message, data} = res.data;
if (success) {
showSuccess(`已修复 ${data} 个通道!`);
await refresh();
} else {
showError(message);
}
}
const sortChannel = (key) => {
if (channels.length === 0) return;
setLoading(true);
@@ -646,7 +647,7 @@ const ChannelsTable = () => {
</Space>
</div>
<Table columns={columns} dataSource={pageData} pagination={{
<Table style={{marginTop: 15}} columns={columns} dataSource={pageData} pagination={{
currentPage: activePage,
pageSize: pageSize,
total: channelCount,
@@ -722,6 +723,15 @@ const ChannelsTable = () => {
>
<Button disabled={!enableBatchDelete} theme='light' type='danger' style={{marginRight: 8}}>删除所选通道</Button>
</Popconfirm>
<Popconfirm
title="确定是否要修复数据库一致性?"
content="进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用"
okType={'warning'}
onConfirm={fixChannelsAbilities}
position={'top'}
>
<Button theme='light' type='secondary' style={{marginRight: 8}}>修复数据库一致性</Button>
</Popconfirm>
</Space>
</div>
</>

View File

@@ -2,7 +2,7 @@ import React, {useEffect, useState} from 'react';
import {Label} from 'semantic-ui-react';
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal} from '@douyinfe/semi-ui';
import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal, Spin} from '@douyinfe/semi-ui';
import {ITEMS_PER_PAGE} from '../constants';
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
import {
@@ -194,7 +194,8 @@ const LogsTable = () => {
const [logs, setLogs] = useState([]);
const [showStat, setShowStat] = useState(false);
const [loading, setLoading] = useState(true);
const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false);
const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [searchKeyword, setSearchKeyword] = useState('');
@@ -247,14 +248,14 @@ const LogsTable = () => {
};
const handleEyeClick = async () => {
if (!showStat) {
if (isAdminUser) {
await getLogStat();
} else {
await getLogSelfStat();
}
setLoadingStat(true);
if (isAdminUser) {
await getLogStat();
} else {
await getLogSelfStat();
}
setShowStat(!showStat);
setShowStat(true);
setLoadingStat(false);
};
const showUserInfo = async (userId) => {
@@ -396,12 +397,12 @@ const LogsTable = () => {
<>
<Layout>
<Header>
<h3>使用明细总消耗额度
{showStat && renderQuota(stat.quota)}
{!showStat &&
<span onClick={handleEyeClick} style={{cursor: 'pointer', color: 'gray'}}>点击查看</span>}
</h3>
<Spin spinning={loadingStat}>
<h3>使用明细总消耗额度
<span onClick={handleEyeClick} style={{cursor: 'pointer', color: 'gray'}}>{showStat?renderQuota(stat.quota):"点击查看"}</span>
</h3>
</Spin>
</Header>
<Form layout='horizontal' style={{marginTop: 10}}>
<>
@@ -434,17 +435,17 @@ const LogsTable = () => {
}
<Form.Section>
<Button label='查询' type="primary" htmlType="submit" className="btn-margin-right"
onClick={refresh}>查询</Button>
onClick={refresh} loading={loading}>查询</Button>
</Form.Section>
</>
</Form>
<Table columns={columns} dataSource={pageData} pagination={{
<Table style={{marginTop: 5}} columns={columns} dataSource={pageData} pagination={{
currentPage: activePage,
pageSize: ITEMS_PER_PAGE,
total: logCount,
pageSizeOpts: [10, 20, 50, 100],
onPageChange: handlePageChange,
}} loading={loading}/>
}}/>
<Select defaultValue="0" style={{width: 120}} onChange={
(value) => {
setLogType(parseInt(value));

View File

@@ -13,7 +13,7 @@ import {
Popover,
Modal,
ImagePreview,
Typography
Typography, Progress
} from '@douyinfe/semi-ui';
import {ITEMS_PER_PAGE} from '../constants';
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
@@ -25,84 +25,83 @@ const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
]
function renderType(type) {
switch (type) {
case 'IMAGINE':
return <Tag color="blue" size='large'>绘图</Tag>;
case 'UPSCALE':
return <Tag color="orange" size='large'>放大</Tag>;
case 'VARIATION':
return <Tag color="purple" size='large'>变换</Tag>;
case 'DESCRIBE':
return <Tag color="yellow" size='large'>图生文</Tag>;
case 'BLEAND':
return <Tag color="lime" size='large'>图混合</Tag>;
default:
return <Tag color="black" size='large'>未知</Tag>;
}
switch (type) {
case 'IMAGINE':
return <Tag color="blue" size='large'>绘图</Tag>;
case 'UPSCALE':
return <Tag color="orange" size='large'>放大</Tag>;
case 'VARIATION':
return <Tag color="purple" size='large'>变换</Tag>;
case 'DESCRIBE':
return <Tag color="yellow" size='large'>图生文</Tag>;
case 'BLEAND':
return <Tag color="lime" size='large'>图混合</Tag>;
default:
return <Tag color="white" size='large'>未知</Tag>;
}
}
function renderCode(code) {
switch (code) {
case 1:
return <Tag color="green" size='large'>已提交</Tag>;
case 21:
return <Tag color="lime" size='large'>排队中</Tag>;
case 22:
return <Tag color="orange" size='large'>重复提交</Tag>;
default:
return <Tag color="black" size='large'>未知</Tag>;
}
switch (code) {
case 1:
return <Tag color="green" size='large'>已提交</Tag>;
case 21:
return <Tag color="lime" size='large'>排队中</Tag>;
case 22:
return <Tag color="orange" size='large'>重复提交</Tag>;
default:
return <Tag color="white" size='large'>未知</Tag>;
}
}
function renderStatus(type) {
// Ensure all cases are string literals by adding quotes.
switch (type) {
case 'SUCCESS':
return <Tag color="green" size='large'>成功</Tag>;
case 'NOT_START':
return <Tag color="grey" size='large'>未启动</Tag>;
case 'SUBMITTED':
return <Tag color="yellow" size='large'>队列中</Tag>;
case 'IN_PROGRESS':
return <Tag color="blue" size='large'>执行中</Tag>;
case 'FAILURE':
return <Tag color="red" size='large'>失败</Tag>;
default:
return <Tag color="black" size='large'>未知</Tag>;
}
// Ensure all cases are string literals by adding quotes.
switch (type) {
case 'SUCCESS':
return <Tag color="green" size='large'>成功</Tag>;
case 'NOT_START':
return <Tag color="grey" size='large'>未启动</Tag>;
case 'SUBMITTED':
return <Tag color="yellow" size='large'>队列中</Tag>;
case 'IN_PROGRESS':
return <Tag color="blue" size='large'>执行中</Tag>;
case 'FAILURE':
return <Tag color="red" size='large'>失败</Tag>;
default:
return <Tag color="white" size='large'>未知</Tag>;
}
}
const renderTimestamp = (timestampInSeconds) => {
const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
const year = date.getFullYear(); // 获取年份
const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份从0开始需要+1并保证两位数
const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
const year = date.getFullYear(); // 获取年份
const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份从0开始需要+1并保证两位数
const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
};
const LogsTable = () => {
const [isModalOpen, setIsModalOpen] = useState(false);
const [modalContent, setModalContent] = useState('');
const columns = [
{
title: '提交时间',
dataIndex: 'submit_time',
render: (text, record, index) => {
return (
<div>
{renderTimestamp(text / 1000)}
</div>
);
},
title: '提交时间',
dataIndex: 'submit_time',
render: (text, record, index) => {
return (
<div>
{renderTimestamp(text / 1000)}
</div>
);
},
},
{
title: '渠道',
@@ -111,48 +110,48 @@ const LogsTable = () => {
render: (text, record, index) => {
return (
<div>
<Tag color={colors[parseInt(text) % colors.length]} size='large' onClick={()=>{
copyText(text); // 假设copyText是用于文本复制的函数
}}> {text} </Tag>
</div>
<div>
<Tag color={colors[parseInt(text) % colors.length]} size='large' onClick={() => {
copyText(text); // 假设copyText是用于文本复制的函数
}}> {text} </Tag>
</div>
);
},
},
{
title: '类型',
dataIndex: 'action',
render: (text, record, index) => {
return (
<div>
{renderType(text)}
</div>
);
},
title: '类型',
dataIndex: 'action',
render: (text, record, index) => {
return (
<div>
{renderType(text)}
</div>
);
},
},
{
title: '任务ID',
dataIndex: 'mj_id',
render: (text, record, index) => {
return (
<div>
{text}
</div>
<div>
{text}
</div>
);
},
},
{
title: '提交结果',
dataIndex: 'code',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
{renderCode(text)}
</div>
);
},
title: '提交结果',
dataIndex: 'code',
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
{renderCode(text)}
</div>
);
},
},
{
title: '任务状态',
@@ -160,9 +159,9 @@ const LogsTable = () => {
className: isAdmin() ? 'tableShow' : 'tableHiddle',
render: (text, record, index) => {
return (
<div>
{renderStatus(text)}
</div>
<div>
{renderStatus(text)}
</div>
);
},
},
@@ -171,99 +170,103 @@ const LogsTable = () => {
dataIndex: 'progress',
render: (text, record, index) => {
return (
<div>
{<span> {text} </span>}
</div>
<div>
{
// 转换例如100%为数字100如果text未定义返回0
<Progress stroke={record.status === "FAILURE"?"var(--semi-color-warning)":null} percent={text ? parseInt(text.replace('%', '')) : 0} showInfo={true}
aria-label="drawing progress"/>
}
</div>
);
},
},
{
title: '结果图片',
dataIndex: 'image_url',
render: (text, record, index) => {
if (!text) {
return '无';
title: '结果图片',
dataIndex: 'image_url',
render: (text, record, index) => {
if (!text) {
return '无';
}
return (
<Button
onClick={() => {
setModalImageUrl(text); // 更新图片URL状态
setIsModalOpenurl(true); // 打开模态框
}}
>
查看图片
</Button>
);
}
return (
<Button
onClick={() => {
setModalImageUrl(text); // 更新图片URL状态
setIsModalOpenurl(true); // 打开模态框
}}
>
查看图片
</Button>
);
}
},
{
title: 'Prompt',
dataIndex: 'prompt',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{showTooltip: true}}
style={{width: 100}}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
{
title: 'PromptEn',
dataIndex: 'prompt_en',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{showTooltip: true}}
style={{width: 100}}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
{
title: '失败原因',
dataIndex: 'fail_reason',
render: (text, record, index) => {
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
// 如果text未定义返回替代文本例如空字符串''或其他
if (!text) {
return '无';
}
return (
<Typography.Text
ellipsis={{showTooltip: true}}
style={{width: 100}}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
}
];
@@ -286,7 +289,7 @@ const LogsTable = () => {
start_timestamp: timestamp2string(now.getTime() / 1000 - 2592000),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
});
const {channel_id, mj_id, start_timestamp, end_timestamp} = inputs;
const {channel_id, mj_id, start_timestamp, end_timestamp} = inputs;
const [stat, setStat] = useState({
quota: 0,
@@ -298,7 +301,6 @@ const LogsTable = () => {
};
const setLogsFormat = (logs) => {
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
@@ -317,9 +319,9 @@ const LogsTable = () => {
let localStartTimestamp = Date.parse(start_timestamp);
let localEndTimestamp = Date.parse(end_timestamp);
if (isAdminUser) {
url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/mj/?p=${startIdx}&channel_id=${channel_id}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
} else {
url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/mj/self/?p=${startIdx}&mj_id=${mj_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
const res = await API.get(url);
const {success, message, data} = res.data;
@@ -368,11 +370,9 @@ const LogsTable = () => {
}, [logType]);
return (
<>
<Layout>
<Form layout='horizontal' style={{marginTop: 10}}>
<>
@@ -400,7 +400,7 @@ const LogsTable = () => {
</Form.Section>
</>
</Form>
<Table columns={columns} dataSource={pageData} pagination={{
<Table style={{marginTop: 5}} columns={columns} dataSource={pageData} pagination={{
currentPage: activePage,
pageSize: ITEMS_PER_PAGE,
total: logCount,
@@ -412,17 +412,17 @@ const LogsTable = () => {
onOk={() => setIsModalOpen(false)}
onCancel={() => setIsModalOpen(false)}
closable={null}
bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式
bodyStyle={{height: '400px', overflow: 'auto'}} // 设置模态框内容区域样式
width={800} // 设置模态框宽度
>
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
<p style={{whiteSpace: 'pre-line'}}>{modalContent}</p>
</Modal>
<ImagePreview
src={modalImageUrl}
visible={isModalOpenurl}
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
/>
</Layout>
</>
);

View File

@@ -23,13 +23,19 @@ const OperationSetting = () => {
DisplayTokenStatEnabled: '',
DrawingEnabled: '',
DataExportEnabled: '',
DataExportDefaultTime: 'hour',
DataExportInterval: 5,
RetryTimes: 0
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
// 精确时间选项(小时,天,周)
const timeOptions = [
{key: 'hour', text: '小时', value: 'hour'},
{key: 'day', text: '天', value: 'day'},
{key: 'week', text: '周', value: 'week'}
];
const getOptions = async () => {
const res = await API.get('/api/option/');
const {success, message, data} = res.data;
@@ -71,7 +77,10 @@ const OperationSetting = () => {
};
const handleInputChange = async (e, {name, value}) => {
if (name.endsWith('Enabled') || name === 'DataExportInterval') {
if (name.endsWith('Enabled') || name === 'DataExportInterval' || name === 'DataExportDefaultTime') {
if (name === 'DataExportDefaultTime') {
localStorage.setItem('data_export_default_time', value);
}
await updateOption(name, value);
} else {
setInputs((inputs) => ({...inputs, [name]: value}));
@@ -234,15 +243,28 @@ const OperationSetting = () => {
name='LogConsumeEnabled'
onChange={handleInputChange}
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox
checked={inputs.DataExportEnabled === 'true'}
label='启用数据看板(实验性)'
name='DataExportEnabled'
onChange={handleInputChange}
/>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, {name, value}) => {
setHistoryTimestamp(value);
}}/>
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider/>
<Header as='h3'>
数据看板
</Header>
<Form.Checkbox
checked={inputs.DataExportEnabled === 'true'}
label='启用数据看板(实验性)'
name='DataExportEnabled'
onChange={handleInputChange}
/>
<Form.Group>
<Form.Input
label='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
name='DataExportInterval'
@@ -254,19 +276,17 @@ const OperationSetting = () => {
value={inputs.DataExportInterval}
placeholder='数据看板更新间隔(分钟,设置过短会影响数据库性能)'
/>
<Form.Select
label='数据看板默认时间粒度(仅修改展示粒度,统计精确到小时)'
options={timeOptions}
name='DataExportDefaultTime'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.DataExportDefaultTime}
placeholder='数据看板默认时间粒度'
/>
</Form.Group>
<Divider/>
<Form.Group widths={4}>
<Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
name='history_timestamp'
onChange={(e, {name, value}) => {
setHistoryTimestamp(value);
}}/>
</Form.Group>
<Form.Button onClick={() => {
deleteHistoryLogs().then();
}}>清理历史日志</Form.Button>
<Divider/>
<Header as='h3'>
监控设置
</Header>

View File

@@ -116,6 +116,45 @@ const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
'purple', 'red', 'teal', 'violet', 'yellow'
]
export const modelColorMap = {
'dall-e': 'rgb(147,112,219)', // 深紫色
'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调
'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调
'midjourney': 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调
'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色
'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色
'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
'gpt-4': 'rgb(135,206,235)', // 天蓝色
'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
'gpt-4-0613': 'rgb(100,149,237)', // 矢车菊蓝
'gpt-4-1106-preview': 'rgb(30,144,255)', // 道奇蓝
'gpt-4-32k': 'rgb(104,111,238)', // 中紫色
'gpt-4-32k-0314': 'rgb(90,105,205)', // 暗灰蓝色
'gpt-4-32k-0613': 'rgb(61,71,139)', // 暗蓝灰色
'gpt-4-all': 'rgb(65,105,225)', // 皇家蓝
'gpt-4-gizmo-*': 'rgb(0,0,255)', // 纯蓝色
'gpt-4-vision-preview': 'rgb(25,25,112)', // 午夜蓝
'text-ada-001': 'rgb(255,192,203)', // 粉红色
'text-babbage-001': 'rgb(255,160,122)', // 浅珊瑚色
'text-curie-001': 'rgb(219,112,147)', // 苍紫罗兰色
'text-davinci-002': 'rgb(199,21,133)', // 中紫罗兰红色
'text-davinci-003': 'rgb(219,112,147)', // 苍紫罗兰色与Curie相同表示同一个系列
'text-davinci-edit-001': 'rgb(255,105,180)', // 热粉色
'text-embedding-ada-002': 'rgb(255,182,193)', // 浅粉红
'text-embedding-v1': 'rgb(255,174,185)', // 浅粉红色(略有区别)
'text-moderation-latest': 'rgb(255,130,171)', // 强粉色
'text-moderation-stable': 'rgb(255,160,122)', // 浅珊瑚色与Babbage相同表示同一类功能
'tts-1': 'rgb(255,140,0)', // 深橙色
'tts-1-1106': 'rgb(255,165,0)', // 橙色
'tts-1-hd': 'rgb(255,215,0)', // 金色
'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别)
'whisper-1': 'rgb(245,245,220)' // 米色
}
export function stringToColor(str) {
let sum = 0;
// 对字符串中的每个字符进行操作

View File

@@ -171,7 +171,7 @@ export function timestamp2string(timestamp) {
);
}
export function timestamp2string1(timestamp) {
export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') {
let date = new Date(timestamp * 1000);
// let year = date.getFullYear().toString();
let month = (date.getMonth() + 1).toString();
@@ -186,15 +186,22 @@ export function timestamp2string1(timestamp) {
if (hour.length === 1) {
hour = '0' + hour;
}
return (
// year +
// '-' +
month +
'-' +
day +
' ' +
hour + ":00"
);
let str = month + '-' + day
if (dataExportDefaultTime === 'hour') {
str += ' ' + hour + ":00"
} else if (dataExportDefaultTime === 'week') {
let nextWeek = new Date(timestamp * 1000 + 6 * 24 * 60 * 60 * 1000);
let nextMonth = (nextWeek.getMonth() + 1).toString();
let nextDay = nextWeek.getDate().toString();
if (nextMonth.length === 1) {
nextMonth = '0' + nextMonth;
}
if (nextDay.length === 1) {
nextDay = '0' + nextDay;
}
str += ' - ' + nextMonth + '-' + nextDay
}
return str;
}
export function downloadTextAsFile(text, filename) {

View File

@@ -3,30 +3,42 @@ import {Button, Col, Form, Layout, Row, Spin} from "@douyinfe/semi-ui";
import VChart from '@visactor/vchart';
import {useEffectOnce} from "usehooks-ts";
import {API, isAdmin, showError, timestamp2string, timestamp2string1} from "../../helpers";
import {getQuotaWithUnit, renderNumber, renderQuotaNumberWithDigit} from "../../helpers/render";
import {
getQuotaWithUnit, modelColorMap,
renderNumber,
renderQuota,
renderQuotaNumberWithDigit,
stringToColor
} from "../../helpers/render";
const Detail = (props) => {
const formRef = useRef();
let now = new Date();
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
start_timestamp: localStorage.getItem('data_export_default_time') === 'hour' ? timestamp2string(now.getTime() / 1000 - 86400) : (localStorage.getItem('data_export_default_time') === 'week' ? timestamp2string(now.getTime() / 1000 - 86400 * 30) : timestamp2string(now.getTime() / 1000 - 86400 * 7)),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: ''
channel: '',
data_export_default_time: ''
});
const {username, token_name, model_name, start_timestamp, end_timestamp, channel} = inputs;
const {username, model_name, start_timestamp, end_timestamp, channel} = inputs;
const isAdminUser = isAdmin();
const initialized = useRef(false)
const [modelDataChart, setModelDataChart] = useState(null);
const [modelDataPieChart, setModelDataPieChart] = useState(null);
const [loading, setLoading] = useState(false);
const [quotaData, setQuotaData] = useState([]);
const [quotaDataPie, setQuotaDataPie] = useState([]);
const [quotaDataLine, setQuotaDataLine] = useState([]);
const [consumeQuota, setConsumeQuota] = useState(0);
const [times, setTimes] = useState(0);
const [dataExportDefaultTime, setDataExportDefaultTime] = useState(localStorage.getItem('data_export_default_time') || 'hour');
const handleInputChange = (value, name) => {
if (name === 'data_export_default_time') {
setDataExportDefaultTime(value);
return
}
setInputs((inputs) => ({...inputs, [name]: value}));
};
@@ -35,8 +47,7 @@ const Detail = (props) => {
data: [
{
id: 'barData',
values: [
]
values: []
}
],
xField: 'Time',
@@ -48,7 +59,8 @@ const Detail = (props) => {
},
title: {
visible: true,
text: '模型消耗分布(小时)'
text: '模型消耗分布',
subtext: '0'
},
bar: {
// The state style of bar
@@ -64,7 +76,7 @@ const Detail = (props) => {
content: [
{
key: datum => datum['Model'],
value: datum => renderQuotaNumberWithDigit(datum['Usage'], 4)
value: datum => renderQuotaNumberWithDigit(parseFloat(datum['Usage']), 4)
}
]
},
@@ -80,11 +92,14 @@ const Detail = (props) => {
array.sort((a, b) => b.value - a.value);
// add $
for (let i = 0; i < array.length; i++) {
array[i].value = renderQuotaNumberWithDigit(array[i].value, 4);
array[i].value = renderQuotaNumberWithDigit(parseFloat(array[i].value), 4);
}
return array;
}
}
},
color: {
specified: modelColorMap
}
};
@@ -94,7 +109,7 @@ const Detail = (props) => {
{
id: 'id0',
values: [
{ type: 'null', value: '0' },
{type: 'null', value: '0'},
]
}
],
@@ -140,6 +155,9 @@ const Detail = (props) => {
}
]
}
},
color: {
specified: modelColorMap
}
};
@@ -150,9 +168,9 @@ const Detail = (props) => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/data/?username=${username}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/data/?username=${username}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&default_time=${dataExportDefaultTime}`;
} else {
url = `/api/data/self/?start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/data/self/?start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&default_time=${dataExportDefaultTime}`;
}
const res = await API.get(url);
const {success, message, data} = res.data;
@@ -166,6 +184,16 @@ const Detail = (props) => {
'created_at': now.getTime() / 1000
})
}
// 根据dataExportDefaultTime重制时间粒度
let timeGranularity = 3600;
if (dataExportDefaultTime === 'day') {
timeGranularity = 86400;
} else if (dataExportDefaultTime === 'week') {
timeGranularity = 604800;
}
data.forEach(item => {
item['created_at'] = Math.floor(item['created_at'] / timeGranularity) * timeGranularity;
});
updateChart(lineChart, pieChart, data);
} else {
showError(message);
@@ -177,7 +205,7 @@ const Detail = (props) => {
await loadQuotaData(modelDataChart, modelDataPieChart);
};
const initChart = async () => {
const initChart = async () => {
let lineChart = modelDataChart
if (!modelDataChart) {
lineChart = new VChart(spec_line, {dom: 'model_data'});
@@ -200,8 +228,12 @@ const Detail = (props) => {
}
let pieData = [];
let lineData = [];
let consumeQuota = 0;
let times = 0;
for (let i = 0; i < data.length; i++) {
const item = data[i];
consumeQuota += item.quota;
times += item.count;
// 合并model_name
let pieItem = pieData.find(it => it.type === item.model_name);
if (pieItem) {
@@ -214,7 +246,7 @@ const Detail = (props) => {
}
// 合并created_at和model_name 为 lineData, created_at 数据类型是小时的时间戳
// 转换日期格式
let createTime = timestamp2string1(item.created_at);
let createTime = timestamp2string1(item.created_at, dataExportDefaultTime);
let lineItem = lineData.find(it => it.Time === createTime && it.Model === item.model_name);
if (lineItem) {
lineItem.Usage += parseFloat(getQuotaWithUnit(item.quota));
@@ -225,17 +257,34 @@ const Detail = (props) => {
"Usage": parseFloat(getQuotaWithUnit(item.quota))
});
}
}
setConsumeQuota(consumeQuota);
setTimes(times);
// sort by count
pieData.sort((a, b) => b.value - a.value);
pieChart.updateData('id0', pieData);
lineChart.updateData('barData', lineData);
spec_pie.title.subtext = `总计:${renderNumber(times)}`;
spec_pie.data[0].values = pieData;
spec_line.title.subtext = `总计:${renderQuota(consumeQuota, 2)}`;
spec_line.data[0].values = lineData;
pieChart.updateSpec(spec_pie);
lineChart.updateSpec(spec_line);
// pieChart.updateData('id0', pieData);
// lineChart.updateData('barData', lineData);
pieChart.reLayout();
lineChart.reLayout();
}
useEffect(() => {
// setDataExportDefaultTime(localStorage.getItem('data_export_default_time'));
// if (dataExportDefaultTime === 'day') {
// // 设置开始时间为7天前
// let st = timestamp2string(now.getTime() / 1000 - 86400 * 7)
// inputs.start_timestamp = st;
// formRef.current.formApi.setValue('start_timestamp', st);
// }
if (!initialized.current) {
initialized.current = true;
initChart();
@@ -249,7 +298,7 @@ const Detail = (props) => {
<h3>数据看板</h3>
</Layout.Header>
<Layout.Content>
<Form layout='horizontal' style={{marginTop: 10}}>
<Form ref={formRef} layout='horizontal' style={{marginTop: 10}}>
<>
<Form.DatePicker field="start_timestamp" label='起始时间' style={{width: 272}}
initValue={start_timestamp}
@@ -261,6 +310,18 @@ const Detail = (props) => {
value={end_timestamp} type='dateTime'
name='end_timestamp'
onChange={value => handleInputChange(value, 'end_timestamp')}/>
<Form.Select field="data_export_default_time" label='时间粒度' style={{width: 176}}
initValue={dataExportDefaultTime}
placeholder={'时间粒度'} name='data_export_default_time'
optionList={
[
{label: '小时', value: 'hour'},
{label: '天', value: 'day'},
{label: '周', value: 'week'}
]
}
onChange={value => handleInputChange(value, 'data_export_default_time')}>
</Form.Select>
{
isAdminUser && <>
<Form.Input field="username" label='用户名称' style={{width: 176}} value={username}

View File

@@ -1,5 +1,4 @@
import React from 'react';
import { Header, Segment } from 'semantic-ui-react';
import MjLogsTable from '../../components/MjLogsTable';
const Midjourney = () => (