mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-13 07:17:26 +00:00
Compare commits
43 Commits
v0.0.6
...
0.1.0-alph
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bdd611fd33 | ||
|
|
1a8a24698f | ||
|
|
c09f3b9282 | ||
|
|
e8bcf60f0a | ||
|
|
14592f9758 | ||
|
|
4036355fae | ||
|
|
fa2efb7357 | ||
|
|
7a8344c40a | ||
|
|
6ad2544415 | ||
|
|
7cd1261a81 | ||
|
|
45ed973f1c | ||
|
|
07fe5a02af | ||
|
|
7a4969c238 | ||
|
|
ad6842da7f | ||
|
|
3c52a0991f | ||
|
|
fd4ef086dc | ||
|
|
7c4719b6ee | ||
|
|
b9d040cf52 | ||
|
|
6e8ff8c057 | ||
|
|
676dc95793 | ||
|
|
1c06bddafe | ||
|
|
3475643257 | ||
|
|
45e1042e58 | ||
|
|
f5a36a05e5 | ||
|
|
5730c69385 | ||
|
|
2d33283afb | ||
|
|
e057c0e42e | ||
|
|
b3f1da44dd | ||
|
|
f048cefed1 | ||
|
|
0226d94ea6 | ||
|
|
42469cb782 | ||
|
|
e1da1e31d5 | ||
|
|
0fdd4fc6e3 | ||
|
|
261dc43c4e | ||
|
|
6463e0539f | ||
|
|
c5f08a757d | ||
|
|
8a9bd08d66 | ||
|
|
751c33a6c0 | ||
|
|
57f664d0fa | ||
|
|
cdcbebce6d | ||
|
|
a29e3765f4 | ||
|
|
766e20719d | ||
|
|
4b93f185bb |
18
README.md
18
README.md
@@ -1,5 +1,5 @@
|
||||
|
||||
# Neko API
|
||||
# New API
|
||||
|
||||
> [!NOTE]
|
||||
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
|
||||
@@ -29,7 +29,21 @@
|
||||
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况,方便二次分销
|
||||
5. 渠道显示已使用额度,支持指定组织访问
|
||||
6. 分页支持选择每页显示数量
|
||||
7. 支持gpt-4-1106-vision-preview,dall-e-3,tts-1
|
||||
7. 支持 gpt-4-1106-vision-preview,dall-e-3,tts-1
|
||||
8. 支持第三方模型 **gps** (gpt-4-gizmo-*),在渠道中添加自定义模型gpt-4-gizmo-*即可
|
||||
9. 兼容原版One API的数据库,可直接使用原版数据库(one-api.db)
|
||||
10. 支持模型按次数收费,可在 系统设置-运营设置 中设置
|
||||
11. 支持gemini-pro模型
|
||||
|
||||
## 部署
|
||||
### 基于 Docker 进行部署
|
||||
```shell
|
||||
# 使用 SQLite 的部署命令:
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
|
||||
# 例如:
|
||||
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
|
||||
```
|
||||
|
||||
## 交流群
|
||||
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="500">
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
var StartTime = time.Now().Unix() // unit: second
|
||||
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
|
||||
var SystemName = "One API"
|
||||
var SystemName = "New API"
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var PayAddress = ""
|
||||
var EpayId = ""
|
||||
@@ -190,6 +190,7 @@ const (
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SendEmail(subject string, receiver string, content string) error {
|
||||
@@ -16,8 +17,9 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
mail := []byte(fmt.Sprintf("To: %s\r\n"+
|
||||
"From: %s<%s>\r\n"+
|
||||
"Subject: %s\r\n"+
|
||||
"Date: %s\r\n"+
|
||||
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
|
||||
receiver, SystemName, SMTPFrom, encodedSubject, content))
|
||||
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content))
|
||||
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
|
||||
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
|
||||
to := strings.Split(receiver, ";")
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, error) {
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
base64String = base64String[idx+1:]
|
||||
@@ -22,20 +22,51 @@ func DecodeBase64ImageData(base64String string) (image.Config, error) {
|
||||
decodedData, err := base64.StdEncoding.DecodeString(base64String)
|
||||
if err != nil {
|
||||
fmt.Println("Error: Failed to decode base64 string")
|
||||
return image.Config{}, err
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
// 创建一个bytes.Buffer用于存储解码后的数据
|
||||
reader := bytes.NewReader(decodedData)
|
||||
config, err := getImageConfig(reader)
|
||||
return config, err
|
||||
config, format, err := getImageConfig(reader)
|
||||
return config, format, err
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, error) {
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
resp, err := http.Head(url)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
_, err = buffer.ReadFrom(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mimeType = resp.Header.Get("Content-Type")
|
||||
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
|
||||
response, err := http.Get(imageUrl)
|
||||
if err != nil {
|
||||
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
|
||||
return image.Config{}, err
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
|
||||
// 限制读取的字节数,防止下载整个图片
|
||||
@@ -45,14 +76,14 @@ func DecodeUrlImageData(imageUrl string) (image.Config, error) {
|
||||
// log.Fatal(err)
|
||||
//}
|
||||
//log.Printf("%x", data)
|
||||
config, err := getImageConfig(limitReader)
|
||||
config, format, err := getImageConfig(limitReader)
|
||||
response.Body.Close()
|
||||
return config, err
|
||||
return config, format, err
|
||||
}
|
||||
|
||||
func getImageConfig(reader io.Reader) (image.Config, error) {
|
||||
func getImageConfig(reader io.Reader) (image.Config, string, error) {
|
||||
// 读取图片的头部信息来获取图片尺寸
|
||||
config, _, err := image.DecodeConfig(reader)
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
|
||||
SysLog(err.Error())
|
||||
@@ -61,9 +92,10 @@ func getImageConfig(reader io.Reader) (image.Config, error) {
|
||||
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
|
||||
SysLog(err.Error())
|
||||
}
|
||||
format = "webp"
|
||||
}
|
||||
if err != nil {
|
||||
return image.Config{}, err
|
||||
return image.Config{}, "", err
|
||||
}
|
||||
return config, nil
|
||||
return config, format, nil
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ var (
|
||||
)
|
||||
|
||||
func printHelp() {
|
||||
fmt.Println("One API " + Version + " - All in one API service for OpenAI API.")
|
||||
fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
|
||||
fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
|
||||
fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
|
||||
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
|
||||
@@ -36,7 +36,14 @@ func init() {
|
||||
}
|
||||
|
||||
if os.Getenv("SESSION_SECRET") != "" {
|
||||
SessionSecret = os.Getenv("SESSION_SECRET")
|
||||
ss := os.Getenv("SESSION_SECRET")
|
||||
if ss == "random_string" {
|
||||
log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
|
||||
log.Println("警告:SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
|
||||
log.Fatal("Please set SESSION_SECRET to a random string.")
|
||||
} else {
|
||||
SessionSecret = ss
|
||||
}
|
||||
}
|
||||
if os.Getenv("SQLITE_PATH") != "" {
|
||||
SQLitePath = os.Getenv("SQLITE_PATH")
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
// 1 === ¥0.014 / 1k tokens
|
||||
var ModelRatio = map[string]float64{
|
||||
"midjourney": 50,
|
||||
"gpt-4-gizmo-*": 15,
|
||||
"gpt-4": 15,
|
||||
"gpt-4-0314": 15,
|
||||
"gpt-4-0613": 15,
|
||||
@@ -23,6 +24,7 @@ var ModelRatio = map[string]float64{
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-0301": 0.75,
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
@@ -59,6 +61,8 @@ var ModelRatio = map[string]float64{
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
@@ -74,6 +78,35 @@ var ModelRatio = map[string]float64{
|
||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||
}
|
||||
|
||||
var ModelPrice = map[string]float64{
|
||||
"gpt-4-gizmo-*": 0.1,
|
||||
}
|
||||
|
||||
func ModelPrice2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(ModelPrice)
|
||||
if err != nil {
|
||||
SysError("error marshalling model price: " + err.Error())
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func UpdateModelPriceByJSONString(jsonStr string) error {
|
||||
ModelPrice = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &ModelPrice)
|
||||
}
|
||||
|
||||
func GetModelPrice(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
price, ok := ModelPrice[name]
|
||||
if !ok {
|
||||
//SysError("model price not found: " + name)
|
||||
return -1
|
||||
}
|
||||
return price
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
jsonBytes, err := json.Marshal(ModelRatio)
|
||||
if err != nil {
|
||||
@@ -88,6 +121,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||
}
|
||||
|
||||
func GetModelRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
ratio, ok := ModelRatio[name]
|
||||
if !ok {
|
||||
SysError("model ratio not found: " + name)
|
||||
|
||||
@@ -168,6 +168,11 @@ func GetRandomString(length int) string {
|
||||
return string(key)
|
||||
}
|
||||
|
||||
func GetRandomInt(max int) int {
|
||||
//rand.Seed(time.Now().UnixNano())
|
||||
return rand.Intn(max)
|
||||
}
|
||||
|
||||
func GetTimestamp() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
@@ -29,6 +29,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
||||
fallthrough
|
||||
case common.ChannelType360:
|
||||
fallthrough
|
||||
case common.ChannelTypeGemini:
|
||||
fallthrough
|
||||
case common.ChannelTypeXunfei:
|
||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||
case common.ChannelTypeAzure:
|
||||
|
||||
@@ -151,6 +151,36 @@ func DeleteDisabledChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type ChannelBatch struct {
|
||||
Ids []int `json:"ids"`
|
||||
}
|
||||
|
||||
func DeleteChannelBatch(c *gin.Context) {
|
||||
channelBatch := ChannelBatch{}
|
||||
err := c.ShouldBindJSON(&channelBatch)
|
||||
if err != nil || len(channelBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.BatchDeleteChannels(channelBatch.Ids)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": len(channelBatch.Ids),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
|
||||
@@ -16,8 +16,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateMidjourneyTask() {
|
||||
/*func UpdateMidjourneyTask() {
|
||||
//revocer
|
||||
//imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
imageModel := "midjourney"
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -28,27 +30,28 @@ func UpdateMidjourneyTask() {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) != 0 {
|
||||
log.Printf("检测到未完成的任务数有: %v", len(tasks))
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
for _, task := range tasks {
|
||||
log.Printf("未完成的任务信息: %v", task)
|
||||
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
|
||||
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask: %v", err)
|
||||
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
|
||||
task.FailReason = fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", task.ChannelId)
|
||||
task.Status = "FAILURE"
|
||||
task.Progress = "100%"
|
||||
err := task.Update()
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
continue
|
||||
}
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
|
||||
log.Printf("requestUrl: %s", requestUrl)
|
||||
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
|
||||
|
||||
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask error: %v", err)
|
||||
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -111,7 +114,7 @@ func UpdateMidjourneyTask() {
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
log.Println(task.MjId + " 构建失败," + task.FailReason)
|
||||
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
if err != nil {
|
||||
@@ -126,8 +129,8 @@ func UpdateMidjourneyTask() {
|
||||
if err != nil {
|
||||
log.Println("fail to increase user quota")
|
||||
}
|
||||
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, 1, logContent)
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -142,6 +145,180 @@ func UpdateMidjourneyTask() {
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func UpdateMidjourneyTaskBulk() {
|
||||
//revocer
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("UpdateMidjourneyTask panic: %v", err)
|
||||
}
|
||||
}()
|
||||
imageModel := "midjourney"
|
||||
ctx := context.TODO()
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
|
||||
taskChannelM := make(map[int][]string)
|
||||
taskM := make(map[string]*model.Midjourney)
|
||||
for _, task := range tasks {
|
||||
if task.MjId == "" {
|
||||
continue
|
||||
}
|
||||
taskM[task.MjId] = task
|
||||
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
continue
|
||||
}
|
||||
midjourneyChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err := model.MjBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
|
||||
|
||||
body, _ := json.Marshal(map[string]any{
|
||||
"ids": taskIds,
|
||||
})
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
|
||||
continue
|
||||
}
|
||||
// 设置超时时间
|
||||
timeout := time.Second * 5
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
// 使用带有超时的 context 创建新的请求
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
continue
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
continue
|
||||
}
|
||||
var responseItems []Midjourney
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
req.Body.Close()
|
||||
cancel()
|
||||
|
||||
for _, responseItem := range responseItems {
|
||||
task := taskM[responseItem.MjId]
|
||||
if !checkMjTaskNeedUpdate(task, responseItem) {
|
||||
continue
|
||||
}
|
||||
|
||||
task.Code = 1
|
||||
task.Progress = responseItem.Progress
|
||||
task.PromptEn = responseItem.PromptEn
|
||||
task.State = responseItem.State
|
||||
task.SubmitTime = responseItem.SubmitTime
|
||||
task.StartTime = responseItem.StartTime
|
||||
task.FinishTime = responseItem.FinishTime
|
||||
task.ImageUrl = responseItem.ImageUrl
|
||||
task.Status = responseItem.Status
|
||||
task.FailReason = responseItem.FailReason
|
||||
if task.Progress != "100%" && responseItem.FailReason != "" {
|
||||
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(task.UserId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio("default")
|
||||
ratio := modelRatio * groupRatio
|
||||
quota := int(ratio * 1 * 1000)
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
|
||||
if oldTask.Code != 1 {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != newTask.Progress {
|
||||
return true
|
||||
}
|
||||
if oldTask.PromptEn != newTask.PromptEn {
|
||||
return true
|
||||
}
|
||||
if oldTask.State != newTask.State {
|
||||
return true
|
||||
}
|
||||
if oldTask.SubmitTime != newTask.SubmitTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.StartTime != newTask.StartTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.ImageUrl != newTask.ImageUrl {
|
||||
return true
|
||||
}
|
||||
if oldTask.Status != newTask.Status {
|
||||
return true
|
||||
}
|
||||
if oldTask.FailReason != newTask.FailReason {
|
||||
return true
|
||||
}
|
||||
if oldTask.FinishTime != newTask.FinishTime {
|
||||
return true
|
||||
}
|
||||
if oldTask.Progress != "100%" && newTask.FailReason != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
@@ -187,6 +364,12 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
if !strings.Contains(common.ServerAddress, "localhost") {
|
||||
for i, midjourney := range logs {
|
||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
logs[i] = midjourney
|
||||
}
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -261,6 +261,15 @@ func init() {
|
||||
Root: "gpt-4-vision-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-1106-vision-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-1106-vision-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-ada-002",
|
||||
Object: "model",
|
||||
@@ -414,6 +423,24 @@ func init() {
|
||||
Root: "PaLM-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro-vision",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro-vision",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_turbo",
|
||||
Object: "model",
|
||||
|
||||
@@ -18,7 +18,7 @@ type ClaudeMetadata struct {
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
|
||||
336
controller/relay-gemini.go
Normal file
336
controller/relay-gemini.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTools struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
//SafetySettings: []GeminiChatSafetySettings{
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_HARASSMENT",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_HATE_SPEECH",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
// {
|
||||
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
// Threshold: "BLOCK_ONLY_HIGH",
|
||||
// },
|
||||
//},
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
},
|
||||
}
|
||||
if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []GeminiChatTools{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
}
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
content := GeminiChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: string(message.Content),
|
||||
},
|
||||
},
|
||||
}
|
||||
openaiContent := message.ParseContent()
|
||||
var parts []GeminiPart
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
|
||||
if part.Type == ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
if imageNum > GeminiVisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
// Converting system prompt to prompt from user for the same reason
|
||||
if content.Role == "system" {
|
||||
content.Role = "user"
|
||||
shouldAddDummyModelMessage = true
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
|
||||
// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
if shouldAddDummyModelMessage {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: "Okay",
|
||||
},
|
||||
},
|
||||
})
|
||||
shouldAddDummyModelMessage = false
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
|
||||
func (g *GeminiChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
}
|
||||
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
return g.Candidates[0].Content.Parts[0].Text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type GeminiChatPromptFeedback struct {
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
content, _ := json.Marshal("")
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: i,
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: stopFinishReason,
|
||||
}
|
||||
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = content
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
choice.FinishReason = &stopFinishReason
|
||||
var response ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// this is used to prevent annoying \ related format bug
|
||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
type dummyStruct struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
var dummy dummyStruct
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
|
||||
usage := Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
@@ -31,7 +31,7 @@ type PaLMChatRequest struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
TopK uint `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
|
||||
@@ -26,6 +26,7 @@ const (
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
)
|
||||
|
||||
var httpClient *http.Client
|
||||
@@ -119,6 +120,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
@@ -180,6 +183,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
case APITypeGemini:
|
||||
requestBaseURL := "https://generativelanguage.googleapis.com"
|
||||
if baseURL != "" {
|
||||
requestBaseURL = baseURL
|
||||
}
|
||||
version := "v1beta"
|
||||
if c.GetString("api_version") != "" {
|
||||
version = c.GetString("api_version")
|
||||
}
|
||||
action := "generateContent"
|
||||
if textRequest.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
//log.Println(fullRequestURL)
|
||||
|
||||
case APITypeZhipu:
|
||||
method := "invoke"
|
||||
if textRequest.Stream {
|
||||
@@ -209,14 +231,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
case RelayModeModerations:
|
||||
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
||||
}
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
||||
}
|
||||
modelRatio := common.GetModelRatio(textRequest.Model)
|
||||
modelPrice := common.GetModelPrice(textRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
if modelPrice == -1 {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
|
||||
}
|
||||
modelRatio = common.GetModelRatio(textRequest.Model)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
@@ -280,6 +312,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeGemini:
|
||||
geminiChatRequest := requestOpenAI2Gemini(textRequest)
|
||||
jsonStr, err := json.Marshal(geminiChatRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeZhipu:
|
||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
||||
jsonStr, err := json.Marshal(zhipuRequest)
|
||||
@@ -375,13 +414,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
case APITypeTencent:
|
||||
req.Header.Set("Authorization", apiKey)
|
||||
case APITypeGemini:
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if apiType != APITypeGemini {
|
||||
// 设置公共头部...
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
|
||||
resp, err = httpClient.Do(req)
|
||||
@@ -418,15 +462,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
defer func(ctx context.Context) {
|
||||
// c.Writer.Flush()
|
||||
go func() {
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
quota := 0
|
||||
if modelPrice == -1 {
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
} else {
|
||||
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
@@ -443,10 +491,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
useTimeSeconds := time.Now().Unix() - startTime.Unix()
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
|
||||
var logContent string
|
||||
if modelPrice == -1 {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds)
|
||||
}
|
||||
logModel := textRequest.Model
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
|
||||
}
|
||||
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
//if quota != 0 {
|
||||
@@ -539,6 +599,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeGemini:
|
||||
if textRequest.Stream {
|
||||
err, responseText := geminiChatStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeZhipu:
|
||||
if isStream {
|
||||
err, usage := zhipuStreamHandler(c, resp)
|
||||
|
||||
@@ -76,12 +76,13 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
|
||||
}
|
||||
var config image.Config
|
||||
var err error
|
||||
var format string
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
|
||||
config, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -101,7 +102,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
|
||||
|
||||
shortSide := config.Width
|
||||
otherSide := config.Height
|
||||
log.Printf("width: %d, height: %d", config.Width, config.Height)
|
||||
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
|
||||
// 缩放倍数
|
||||
scale := 1.0
|
||||
if config.Height < shortSide {
|
||||
@@ -160,10 +161,27 @@ func countTokenMessages(messages []Message, model string) (int, error) {
|
||||
} else {
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == "image_url" {
|
||||
imageTokenNum, err := getImageToken(&m.ImageUrl)
|
||||
var imageTokenNum int
|
||||
if str, ok := m.ImageUrl.(string); ok {
|
||||
imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
|
||||
} else {
|
||||
imageUrlMap := m.ImageUrl.(map[string]interface{})
|
||||
detail, ok := imageUrlMap["detail"]
|
||||
if ok {
|
||||
imageUrlMap["detail"] = detail.(string)
|
||||
} else {
|
||||
imageUrlMap["detail"] = "auto"
|
||||
}
|
||||
imageUrl := MessageImageUrl{
|
||||
Url: imageUrlMap["url"].(string),
|
||||
Detail: imageUrlMap["detail"].(string),
|
||||
}
|
||||
imageTokenNum, err = getImageToken(&imageUrl)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tokenNum += imageTokenNum
|
||||
log.Printf("image token num: %d", imageTokenNum)
|
||||
} else {
|
||||
@@ -177,12 +195,12 @@ func countTokenMessages(messages []Message, model string) (int, error) {
|
||||
}
|
||||
|
||||
func countTokenInput(input any, model string) int {
|
||||
switch input.(type) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return countTokenText(input.(string), model)
|
||||
return countTokenText(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range input.([]string) {
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return countTokenText(text, model)
|
||||
|
||||
@@ -33,7 +33,7 @@ type XunfeiChatRequest struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
|
||||
@@ -19,9 +19,9 @@ type Message struct {
|
||||
}
|
||||
|
||||
type MediaMessage struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl MessageImageUrl `json:"image_url,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl any `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type MessageImageUrl struct {
|
||||
@@ -29,6 +29,60 @@ type MessageImageUrl struct {
|
||||
Detail string `json:"detail"`
|
||||
}
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
)
|
||||
|
||||
func (m Message) ParseContent() []MediaMessage {
|
||||
var contentList []MediaMessage
|
||||
var stringContent string
|
||||
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: stringContent,
|
||||
})
|
||||
return contentList
|
||||
}
|
||||
var arrayContent []json.RawMessage
|
||||
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
|
||||
for _, contentItem := range arrayContent {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
|
||||
continue
|
||||
}
|
||||
switch contentMap["type"] {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
}
|
||||
case ContentTypeImageURL:
|
||||
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
||||
detail, ok := subObj["detail"]
|
||||
if ok {
|
||||
subObj["detail"] = detail.(string)
|
||||
} else {
|
||||
subObj["detail"] = "auto"
|
||||
}
|
||||
contentList = append(contentList, MediaMessage{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageUrl: MessageImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
Detail: subObj["detail"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
@@ -53,7 +107,7 @@ type GeneralOpenAIRequest struct {
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
@@ -91,14 +145,14 @@ type AudioRequest struct {
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
MaxTokens uint `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
MaxTokens uint `json:"max_tokens"`
|
||||
//Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
|
||||
@@ -152,6 +152,21 @@ func Register(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if exist {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户名已存在,或已注销",
|
||||
})
|
||||
return
|
||||
}
|
||||
affCode := user.AffCode // this code is the inviter's code, not the user's own code
|
||||
inviterId, _ := model.GetUserIdByAffCode(affCode)
|
||||
cleanUser := model.User{
|
||||
@@ -525,7 +540,7 @@ func DeleteUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = model.DeleteUserById(id)
|
||||
err = model.HardDeleteUserById(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
@@ -632,7 +647,7 @@ func ManageUser(c *gin.Context) {
|
||||
Username: req.Username,
|
||||
}
|
||||
// Fill attributes
|
||||
model.DB.Where(&user).First(&user)
|
||||
model.DB.Unscoped().Where(&user).First(&user)
|
||||
if user.Id == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -668,7 +683,7 @@ func ManageUser(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := user.Delete(); err != nil {
|
||||
if err := user.HardDelete(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
version: '3.4'
|
||||
|
||||
services:
|
||||
one-api:
|
||||
new-api:
|
||||
image: calciumion/new-api:latest
|
||||
container_name: one-api
|
||||
container_name: new-api
|
||||
restart: always
|
||||
command: --log-dir /app/logs
|
||||
ports:
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- ./data:/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
|
||||
12
go.mod
12
go.mod
@@ -19,7 +19,7 @@ require (
|
||||
github.com/samber/lo v1.38.1
|
||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/crypto v0.17.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
@@ -44,13 +44,14 @@ require (
|
||||
github.com/gorilla/sessions v1.2.1 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.3.1 // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.1 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
@@ -64,8 +65,9 @@ require (
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
14
go.sum
14
go.sum
@@ -81,6 +81,10 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
|
||||
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
|
||||
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
|
||||
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
@@ -108,6 +112,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
@@ -172,11 +178,15 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
|
||||
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -189,12 +199,16 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
|
||||
4
main.go
4
main.go
@@ -27,7 +27,7 @@ var indexPage []byte
|
||||
|
||||
func main() {
|
||||
common.SetupLogger()
|
||||
common.SysLog("One API " + common.Version + " started")
|
||||
common.SysLog("New API " + common.Version + " started")
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func main() {
|
||||
}
|
||||
go controller.AutomaticallyTestChannels(frequency)
|
||||
}
|
||||
go controller.UpdateMidjourneyTask()
|
||||
go controller.UpdateMidjourneyTaskBulk()
|
||||
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
|
||||
common.BatchUpdateEnabled = true
|
||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||
|
||||
@@ -91,11 +91,11 @@ func TokenAuth() func(c *gin.Context) {
|
||||
key = c.Request.Header.Get("mj-api-secret")
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts := strings.Split(key, "-")
|
||||
parts = strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
} else {
|
||||
key = strings.TrimPrefix(key, "sk-")
|
||||
parts := strings.Split(key, "-")
|
||||
parts = strings.Split(key, "-")
|
||||
key = parts[0]
|
||||
}
|
||||
token, err := model.ValidateUserToken(key)
|
||||
|
||||
@@ -107,6 +107,8 @@ func Distribute() func(c *gin.Context) {
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
)
|
||||
|
||||
type Ability struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
|
||||
Group string `json:"group" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
|
||||
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||
Weight uint `json:"weight" gorm:"default:0;index"`
|
||||
}
|
||||
|
||||
func GetGroupModels(group string) []string {
|
||||
@@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
var abilities []Ability
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
@@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
} else {
|
||||
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||
err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
channel := Channel{}
|
||||
channel.Id = ability.ChannelId
|
||||
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
|
||||
if len(abilities) > 0 {
|
||||
// Randomly choose one
|
||||
weightSum := uint(0)
|
||||
for _, ability_ := range abilities {
|
||||
weightSum += ability_.Weight
|
||||
}
|
||||
if weightSum == 0 {
|
||||
// All weight is 0, randomly choose one
|
||||
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
|
||||
} else {
|
||||
// Randomly choose one
|
||||
weight := common.GetRandomInt(int(weightSum))
|
||||
for _, ability_ := range abilities {
|
||||
weight -= int(ability_.Weight)
|
||||
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
|
||||
if weight <= 0 {
|
||||
channel.Id = ability_.ChannelId
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
err = DB.First(&channel, "id = ?", channel.Id).Error
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
@@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
|
||||
ChannelId: channel.Id,
|
||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||
Priority: channel.Priority,
|
||||
Weight: uint(channel.GetWeight()),
|
||||
}
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
|
||||
@@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelsIDM map[int]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
func InitChannelCache() {
|
||||
@@ -149,10 +150,12 @@ func InitChannelCache() {
|
||||
groups[ability.Group] = true
|
||||
}
|
||||
newGroup2model2channels := make(map[string]map[string][]*Channel)
|
||||
newChannelsIDM := make(map[int]*Channel)
|
||||
for group := range groups {
|
||||
newGroup2model2channels[group] = make(map[string][]*Channel)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
newChannelsIDM[channel.Id] = channel
|
||||
groups := strings.Split(channel.Group, ",")
|
||||
for _, group := range groups {
|
||||
models := strings.Split(channel.Models, ",")
|
||||
@@ -177,6 +180,7 @@ func InitChannelCache() {
|
||||
|
||||
channelSyncLock.Lock()
|
||||
group2model2channels = newGroup2model2channels
|
||||
channelsIDM = newChannelsIDM
|
||||
channelSyncLock.Unlock()
|
||||
common.SysLog("channels synced from database")
|
||||
}
|
||||
@@ -194,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
// if memory cache is disabled, get channel directly from database
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
}
|
||||
@@ -214,6 +219,41 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
}
|
||||
}
|
||||
}
|
||||
idx := rand.Intn(endIdx)
|
||||
return channels[idx], nil
|
||||
// Calculate the total weight of all channels up to endIdx
|
||||
totalWeight := 0
|
||||
for _, channel := range channels[:endIdx] {
|
||||
totalWeight += channel.GetWeight()
|
||||
}
|
||||
|
||||
if totalWeight == 0 {
|
||||
// If all weights are 0, select a channel randomly
|
||||
return channels[rand.Intn(endIdx)], nil
|
||||
}
|
||||
|
||||
// Generate a random value in the range [0, totalWeight)
|
||||
randomWeight := rand.Intn(totalWeight)
|
||||
|
||||
// Find a channel based on its weight
|
||||
for _, channel := range channels[:endIdx] {
|
||||
randomWeight -= channel.GetWeight()
|
||||
if randomWeight <= 0 {
|
||||
return channel, nil
|
||||
}
|
||||
}
|
||||
// return the last channel if no channel is found
|
||||
return channels[endIdx-1], nil
|
||||
}
|
||||
|
||||
func CacheGetChannel(id int) (*Channel, error) {
|
||||
if !common.MemoryCacheEnabled {
|
||||
return GetChannelById(id, true)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
|
||||
c, ok := channelsIDM[id]
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type Channel struct {
|
||||
Balance float64 `json:"balance"` // in USD
|
||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||
Models string `json:"models"`
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
Group string `json:"group" gorm:"type:varchar(255);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
@@ -86,6 +86,26 @@ func BatchInsertChannels(channels []Channel) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func BatchDeleteChannels(ids []int) error {
|
||||
//使用事务 删除channel表和channel_ability表
|
||||
tx := DB.Begin()
|
||||
err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
|
||||
if err != nil {
|
||||
// 回滚事务
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
|
||||
if err != nil {
|
||||
// 回滚事务
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
// 提交事务
|
||||
tx.Commit()
|
||||
return err
|
||||
}
|
||||
|
||||
func (channel *Channel) GetPriority() int64 {
|
||||
if channel.Priority == nil {
|
||||
return 0
|
||||
@@ -93,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
|
||||
return *channel.Priority
|
||||
}
|
||||
|
||||
func (channel *Channel) GetWeight() int {
|
||||
if channel.Weight == nil {
|
||||
return 0
|
||||
}
|
||||
return int(*channel.Weight)
|
||||
}
|
||||
|
||||
func (channel *Channel) GetBaseURL() string {
|
||||
if channel.BaseURL == nil {
|
||||
return ""
|
||||
|
||||
@@ -52,6 +52,14 @@ func chooseDB() (*gorm.DB, error) {
|
||||
}
|
||||
// Use MySQL
|
||||
common.SysLog("using MySQL as database")
|
||||
// check parseTime
|
||||
if !strings.Contains(dsn, "parseTime") {
|
||||
if strings.Contains(dsn, "?") {
|
||||
dsn += "&parseTime=true"
|
||||
} else {
|
||||
dsn += "?parseTime=true"
|
||||
}
|
||||
}
|
||||
return gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
|
||||
@@ -131,3 +131,9 @@ func (midjourney *Midjourney) Update() error {
|
||||
err = DB.Save(midjourney).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func MjBulkUpdate(taskIDs []string, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("mj_id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
@@ -70,6 +70,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
|
||||
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
|
||||
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
|
||||
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
@@ -220,6 +221,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
err = common.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = common.UpdateGroupRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = common.UpdateModelPriceByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
case "ChatLink":
|
||||
|
||||
@@ -45,7 +45,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
if err == nil {
|
||||
if token.Status == common.TokenStatusExhausted {
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
return nil, errors.New("该令牌额度已用尽 token.Status == common.TokenStatusExhausted " + key)
|
||||
} else if token.Status == common.TokenStatusExpired {
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
return nil, errors.New(fmt.Sprintf("%s 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", token.Key, token.RemainQuota))
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@@ -11,26 +11,51 @@ import (
|
||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
|
||||
// Otherwise, the sensitive information will be saved on local storage in plain text!
|
||||
type User struct {
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
||||
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
||||
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
|
||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
Id int `json:"id"`
|
||||
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
|
||||
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
|
||||
Quota int `json:"quota" gorm:"type:int;default:0"`
|
||||
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
|
||||
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
|
||||
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
|
||||
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
|
||||
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
|
||||
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index"`
|
||||
}
|
||||
|
||||
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
||||
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
var user User
|
||||
|
||||
// err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
|
||||
// check email if empty
|
||||
var err error
|
||||
if email == "" {
|
||||
err = DB.Unscoped().First(&user, "username = ?", username).Error
|
||||
} else {
|
||||
err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// not exist, return false, nil
|
||||
return false, nil
|
||||
}
|
||||
// other error, return false, err
|
||||
return false, err
|
||||
}
|
||||
// exist, return true, nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetMaxUserId() int {
|
||||
@@ -40,7 +65,7 @@ func GetMaxUserId() int {
|
||||
}
|
||||
|
||||
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
||||
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
||||
err = DB.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
@@ -80,6 +105,14 @@ func DeleteUserById(id int) (err error) {
|
||||
return user.Delete()
|
||||
}
|
||||
|
||||
func HardDeleteUserById(id int) error {
|
||||
if id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func inviteUser(inviterId int) (err error) {
|
||||
user, err := GetUserById(inviterId, true)
|
||||
if err != nil {
|
||||
@@ -183,6 +216,14 @@ func (user *User) Delete() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (user *User) HardDelete() error {
|
||||
if user.Id == 0 {
|
||||
return errors.New("id 为空!")
|
||||
}
|
||||
err := DB.Unscoped().Delete(user).Error
|
||||
return err
|
||||
}
|
||||
|
||||
// ValidateAndFill check password & user status
|
||||
func (user *User) ValidateAndFill() (err error) {
|
||||
// When querying with struct, GORM will only query with non-zero fields,
|
||||
|
||||
@@ -83,6 +83,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.PUT("/", controller.UpdateChannel)
|
||||
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
|
||||
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
||||
channelRoute.POST("/batch", controller.DeleteChannelBatch)
|
||||
}
|
||||
tokenRoute := apiRouter.Group("/token")
|
||||
tokenRoute.Use(middleware.UserAuth())
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-dropzone": "^14.2.3",
|
||||
"react-fireworks": "^1.0.4",
|
||||
"react-router-dom": "^6.3.0",
|
||||
"react-scripts": "5.0.1",
|
||||
"react-toastify": "^9.0.8",
|
||||
|
||||
@@ -74,6 +74,11 @@ function renderBalance(type, balance) {
|
||||
|
||||
const ChannelsTable = () => {
|
||||
const columns = [
|
||||
// {
|
||||
// title: '',
|
||||
// dataIndex: 'checkbox',
|
||||
// className: 'checkbox',
|
||||
// },
|
||||
{
|
||||
title: 'ID',
|
||||
dataIndex: 'id',
|
||||
@@ -158,11 +163,30 @@ const ChannelsTable = () => {
|
||||
<div>
|
||||
<InputNumber
|
||||
style={{width: 70}}
|
||||
name='name'
|
||||
name='priority'
|
||||
onChange={value => {
|
||||
manageChannel(record.id, 'priority', record, value);
|
||||
}}
|
||||
defaultValue={record.priority}
|
||||
min={-999}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
title: '权重',
|
||||
dataIndex: 'weight',
|
||||
render: (text, record, index) => {
|
||||
return (
|
||||
<div>
|
||||
<InputNumber
|
||||
style={{width: 70}}
|
||||
name='weight'
|
||||
onChange={value => {
|
||||
manageChannel(record.id, 'weight', record, value);
|
||||
}}
|
||||
defaultValue={record.weight}
|
||||
min={0}
|
||||
/>
|
||||
</div>
|
||||
@@ -235,9 +259,11 @@ const ChannelsTable = () => {
|
||||
const [channelCount, setChannelCount] = useState(pageSize);
|
||||
const [groupOptions, setGroupOptions] = useState([]);
|
||||
const [showEdit, setShowEdit] = useState(false);
|
||||
const [enableBatchDelete, setEnableBatchDelete] = useState(false);
|
||||
const [editingChannel, setEditingChannel] = useState({
|
||||
id: undefined,
|
||||
});
|
||||
const [selectedChannels, setSelectedChannels] = useState([]);
|
||||
|
||||
const removeRecord = id => {
|
||||
let newDataSource = [...channels];
|
||||
@@ -484,6 +510,27 @@ const ChannelsTable = () => {
|
||||
setUpdatingBalance(false);
|
||||
};
|
||||
|
||||
const batchDeleteChannels = async () => {
|
||||
if (selectedChannels.length === 0) {
|
||||
showError('请先选择要删除的通道!');
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
let ids = [];
|
||||
selectedChannels.forEach((channel) => {
|
||||
ids.push(channel.id);
|
||||
});
|
||||
const res = await API.post(`/api/channel/batch`, {ids: ids});
|
||||
const {success, message, data} = res.data;
|
||||
if (success) {
|
||||
showSuccess(`已删除 ${data} 个通道!`);
|
||||
await refresh();
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setLoading(false);
|
||||
}
|
||||
|
||||
const sortChannel = (key) => {
|
||||
if (channels.length === 0) return;
|
||||
setLoading(true);
|
||||
@@ -557,6 +604,7 @@ const ChannelsTable = () => {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
return (
|
||||
<>
|
||||
<EditChannel refresh={refresh} visible={showEdit} handleClose={closeEdit} editingChannel={editingChannel}/>
|
||||
@@ -583,16 +631,18 @@ const ChannelsTable = () => {
|
||||
</Form>
|
||||
<div style={{marginTop: 10, display: 'flex'}}>
|
||||
<Space>
|
||||
<Typography.Text strong>使用ID排序</Typography.Text>
|
||||
<Switch checked={idSort} label='使用ID排序' uncheckedText="关" aria-label="是否用ID排序" onChange={(v) => {
|
||||
localStorage.setItem('id-sort', v + '')
|
||||
setIdSort(v)
|
||||
loadChannels(0, pageSize, v)
|
||||
.then()
|
||||
.catch((reason) => {
|
||||
showError(reason);
|
||||
})
|
||||
}}></Switch>
|
||||
<Space>
|
||||
<Typography.Text strong>使用ID排序</Typography.Text>
|
||||
<Switch checked={idSort} label='使用ID排序' uncheckedText="关" aria-label="是否用ID排序" onChange={(v) => {
|
||||
localStorage.setItem('id-sort', v + '')
|
||||
setIdSort(v)
|
||||
loadChannels(0, pageSize, v)
|
||||
.then()
|
||||
.catch((reason) => {
|
||||
showError(reason);
|
||||
})
|
||||
}}></Switch>
|
||||
</Space>
|
||||
</Space>
|
||||
</div>
|
||||
|
||||
@@ -607,7 +657,15 @@ const ChannelsTable = () => {
|
||||
handlePageSizeChange(size).then()
|
||||
},
|
||||
onPageChange: handlePageChange,
|
||||
}} loading={loading} onRow={handleRow}/>
|
||||
}} loading={loading} onRow={handleRow} rowSelection={
|
||||
enableBatchDelete ?
|
||||
{
|
||||
onChange: (selectedRowKeys, selectedRows) => {
|
||||
// console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows);
|
||||
setSelectedChannels(selectedRows);
|
||||
},
|
||||
} : null
|
||||
}/>
|
||||
<div style={{display: isMobile()?'':'flex', marginTop: isMobile()?0:-45, zIndex: 999, position: 'relative', pointerEvents: 'none'}}>
|
||||
<Space style={{pointerEvents: 'auto'}}>
|
||||
<Button theme='light' type='primary' style={{marginRight: 8}} onClick={
|
||||
@@ -622,7 +680,7 @@ const ChannelsTable = () => {
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={testAllChannels}
|
||||
position={isMobile()?'top':''}
|
||||
position={isMobile()?'top':'top'}
|
||||
>
|
||||
<Button theme='light' type='warning' style={{marginRight: 8}}>测试所有已启用通道</Button>
|
||||
</Popconfirm>
|
||||
@@ -648,6 +706,24 @@ const ChannelsTable = () => {
|
||||
|
||||
{/*</div>*/}
|
||||
</div>
|
||||
<div style={{marginTop: 20}}>
|
||||
<Space>
|
||||
<Typography.Text strong>开启批量删除</Typography.Text>
|
||||
<Switch label='开启批量删除' uncheckedText="关" aria-label="是否开启批量删除" onChange={(v) => {
|
||||
setEnableBatchDelete(v)
|
||||
}}></Switch>
|
||||
<Popconfirm
|
||||
title="确定是否要删除所选通道?"
|
||||
content="此修改将不可逆"
|
||||
okType={'danger'}
|
||||
onConfirm={batchDeleteChannels}
|
||||
disabled={!enableBatchDelete}
|
||||
position={'top'}
|
||||
>
|
||||
<Button disabled={!enableBatchDelete} theme='light' type='danger' style={{marginRight: 8}}>删除所选通道</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, {useContext, useEffect, useState} from 'react';
|
||||
import React, {useContext, useEffect, useRef, useState} from 'react';
|
||||
import {Link, useNavigate} from 'react-router-dom';
|
||||
import {UserContext} from '../context/User';
|
||||
|
||||
@@ -6,18 +6,12 @@ import {Button, Container, Icon, Menu, Segment} from 'semantic-ui-react';
|
||||
import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../helpers';
|
||||
import '../index.css';
|
||||
|
||||
import fireworks from 'react-fireworks';
|
||||
|
||||
import {
|
||||
IconAt,
|
||||
IconHistogram,
|
||||
IconGift,
|
||||
IconKey,
|
||||
IconUser,
|
||||
IconLayers,
|
||||
IconHelpCircle,
|
||||
IconCreditCard,
|
||||
IconSemiLogo,
|
||||
IconHome,
|
||||
IconImage
|
||||
IconHelpCircle
|
||||
} from '@douyinfe/semi-icons';
|
||||
import {Nav, Avatar, Dropdown, Layout, Switch} from '@douyinfe/semi-ui';
|
||||
import {stringToColor} from "../helpers/render";
|
||||
@@ -49,6 +43,8 @@ const HeaderBar = () => {
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
var themeMode = localStorage.getItem('theme-mode');
|
||||
const currentDate = new Date();
|
||||
const isNewYear = currentDate.getMonth() === 0 && currentDate.getDate() === 1;
|
||||
|
||||
async function logout() {
|
||||
setShowSidebar(false);
|
||||
@@ -59,10 +55,24 @@ const HeaderBar = () => {
|
||||
navigate('/login');
|
||||
}
|
||||
|
||||
const handleNewYearClick = () => {
|
||||
fireworks.init("root",{});
|
||||
fireworks.start();
|
||||
setTimeout(() => {
|
||||
fireworks.stop();
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 10000);
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (themeMode === 'dark') {
|
||||
switchMode(true);
|
||||
}
|
||||
if (isNewYear) {
|
||||
console.log('Happy New Year!');
|
||||
}
|
||||
}, []);
|
||||
|
||||
const switchMode = (model) => {
|
||||
@@ -105,6 +115,19 @@ const HeaderBar = () => {
|
||||
}}
|
||||
footer={
|
||||
<>
|
||||
{isNewYear &&
|
||||
// happy new year
|
||||
<Dropdown
|
||||
position="bottomRight"
|
||||
render={
|
||||
<Dropdown.Menu>
|
||||
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
<Nav.Item itemKey={'new-year'} text={'🏮'}/>
|
||||
</Dropdown>
|
||||
}
|
||||
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
|
||||
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
|
||||
{userState.user ?
|
||||
|
||||
@@ -2,7 +2,19 @@ import React, {useEffect, useState} from 'react';
|
||||
import {Label} from 'semantic-ui-react';
|
||||
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
|
||||
|
||||
import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal } from '@douyinfe/semi-ui';
|
||||
import {
|
||||
Table,
|
||||
Avatar,
|
||||
Tag,
|
||||
Form,
|
||||
Button,
|
||||
Layout,
|
||||
Select,
|
||||
Popover,
|
||||
Modal,
|
||||
ImagePreview,
|
||||
Typography
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {ITEMS_PER_PAGE} from '../constants';
|
||||
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
|
||||
|
||||
@@ -194,19 +206,16 @@ const LogsTable = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
text.length > 10 ?
|
||||
<>
|
||||
{text.slice(0, 10)}
|
||||
<Button
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
查看全部
|
||||
</Button>
|
||||
</>
|
||||
: text
|
||||
<Typography.Text
|
||||
ellipsis={{ showTooltip: true }}
|
||||
style={{ width: 100 }}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
},
|
||||
@@ -220,19 +229,16 @@ const LogsTable = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
text.length > 10 ?
|
||||
<>
|
||||
{text.slice(0, 10)}
|
||||
<Button
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
查看全部
|
||||
</Button>
|
||||
</>
|
||||
: text
|
||||
<Typography.Text
|
||||
ellipsis={{ showTooltip: true }}
|
||||
style={{ width: 100 }}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
},
|
||||
@@ -246,19 +252,16 @@ const LogsTable = () => {
|
||||
}
|
||||
|
||||
return (
|
||||
text.length > 10 ?
|
||||
<>
|
||||
{text.slice(0, 10)}
|
||||
<Button
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
查看全部
|
||||
</Button>
|
||||
</>
|
||||
: text
|
||||
<Typography.Text
|
||||
ellipsis={{ showTooltip: true }}
|
||||
style={{ width: 100 }}
|
||||
onClick={() => {
|
||||
setModalContent(text);
|
||||
setIsModalOpen(true);
|
||||
}}
|
||||
>
|
||||
{text}
|
||||
</Typography.Text>
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -414,15 +417,11 @@ const LogsTable = () => {
|
||||
>
|
||||
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
|
||||
</Modal>
|
||||
{/* 模态框组件,用于展示图片 */}
|
||||
<Modal
|
||||
title="图片预览"
|
||||
visible={isModalOpenurl}
|
||||
onCancel={() => setIsModalOpenurl(false)}
|
||||
footer={null} // 模态框不显示底部按钮
|
||||
>
|
||||
<img src={modalImageUrl} style={{ width: '100%' }} alt="结果图片" />
|
||||
</Modal>
|
||||
<ImagePreview
|
||||
src={modalImageUrl}
|
||||
visible={isModalOpenurl}
|
||||
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
|
||||
/>
|
||||
|
||||
</Layout>
|
||||
</>
|
||||
|
||||
@@ -10,6 +10,7 @@ const OperationSetting = () => {
|
||||
QuotaRemindThreshold: 0,
|
||||
PreConsumedQuota: 0,
|
||||
ModelRatio: '',
|
||||
ModelPrice: '',
|
||||
GroupRatio: '',
|
||||
TopUpLink: '',
|
||||
ChatLink: '',
|
||||
@@ -30,7 +31,7 @@ const OperationSetting = () => {
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
|
||||
if (item.key === 'ModelRatio' || item.key === 'GroupRatio'|| item.key === 'ModelPrice') {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
}
|
||||
newInputs[item.key] = item.value;
|
||||
@@ -97,6 +98,13 @@ const OperationSetting = () => {
|
||||
}
|
||||
await updateOption('GroupRatio', inputs.GroupRatio);
|
||||
}
|
||||
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
|
||||
if (!verifyJSON(inputs.ModelPrice)) {
|
||||
showError('模型固定价格不是合法的 JSON 字符串');
|
||||
return;
|
||||
}
|
||||
await updateOption('ModelPrice', inputs.ModelPrice);
|
||||
}
|
||||
break;
|
||||
case 'quota':
|
||||
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
|
||||
@@ -315,6 +323,17 @@ const OperationSetting = () => {
|
||||
<Header as='h3'>
|
||||
倍率设置
|
||||
</Header>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
|
||||
name='ModelPrice'
|
||||
onChange={handleInputChange}
|
||||
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
|
||||
autoComplete='new-password'
|
||||
value={inputs.ModelPrice}
|
||||
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1,一次消耗0.1刀'
|
||||
/>
|
||||
</Form.Group>
|
||||
<Form.Group widths='equal'>
|
||||
<Form.TextArea
|
||||
label='模型倍率'
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
import {getQuotaPerUnit, renderQuota, renderQuotaWithPrompt, stringToColor} from "../helpers/render";
|
||||
import EditToken from "../pages/Token/EditToken";
|
||||
import EditUser from "../pages/User/EditUser";
|
||||
import passwordResetConfirm from "./PasswordResetConfirm";
|
||||
|
||||
const PersonalSetting = () => {
|
||||
const [userState, userDispatch] = useContext(UserContext);
|
||||
@@ -29,9 +30,12 @@ const PersonalSetting = () => {
|
||||
wechat_verification_code: '',
|
||||
email_verification_code: '',
|
||||
email: '',
|
||||
self_account_deletion_confirmation: ''
|
||||
self_account_deletion_confirmation: '',
|
||||
set_new_password: '',
|
||||
set_new_password_confirmation: '',
|
||||
});
|
||||
const [status, setStatus] = useState({});
|
||||
const [showChangePasswordModal, setShowChangePasswordModal] = useState(false);
|
||||
const [showWeChatBindModal, setShowWeChatBindModal] = useState(false);
|
||||
const [showEmailBindModal, setShowEmailBindModal] = useState(false);
|
||||
const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false);
|
||||
@@ -180,6 +184,27 @@ const PersonalSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const changePassword = async () => {
|
||||
if (inputs.set_new_password !== inputs.set_new_password_confirmation) {
|
||||
showError('两次输入的密码不一致!');
|
||||
return;
|
||||
}
|
||||
const res = await API.put(
|
||||
`/api/user/self`,
|
||||
{
|
||||
password: inputs.set_new_password
|
||||
}
|
||||
);
|
||||
const {success, message} = res.data;
|
||||
if (success) {
|
||||
showSuccess('密码修改成功!');
|
||||
setShowWeChatBindModal(false);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
setShowChangePasswordModal(false);
|
||||
};
|
||||
|
||||
const transfer = async () => {
|
||||
if (transferAmount < getQuotaPerUnit()) {
|
||||
showError('划转金额最低为' + renderQuota(getQuotaPerUnit()));
|
||||
@@ -420,6 +445,9 @@ const PersonalSetting = () => {
|
||||
<Space>
|
||||
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
|
||||
<Button onClick={() => {
|
||||
setShowChangePasswordModal(true);
|
||||
}}>修改密码</Button>
|
||||
<Button type={'danger'} onClick={() => {
|
||||
setShowAccountDeleteModal(true);
|
||||
}}>删除个人账户</Button>
|
||||
</Space>
|
||||
@@ -543,6 +571,39 @@ const PersonalSetting = () => {
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
<Modal
|
||||
onCancel={() => setShowChangePasswordModal(false)}
|
||||
visible={showChangePasswordModal}
|
||||
size={'small'}
|
||||
centered={true}
|
||||
onOk={changePassword}
|
||||
>
|
||||
<div style={{marginTop: 20}}>
|
||||
<Input
|
||||
name='set_new_password'
|
||||
placeholder='新密码'
|
||||
value={inputs.set_new_password}
|
||||
onChange={(value)=>handleInputChange('set_new_password', value)}
|
||||
/>
|
||||
<Input
|
||||
style={{marginTop: 20}}
|
||||
name='set_new_password_confirmation'
|
||||
placeholder='确认新密码'
|
||||
value={inputs.set_new_password_confirmation}
|
||||
onChange={(value)=>handleInputChange('set_new_password_confirmation', value)}
|
||||
/>
|
||||
{turnstileEnabled ? (
|
||||
<Turnstile
|
||||
sitekey={turnstileSiteKey}
|
||||
onVerify={(token) => {
|
||||
setTurnstileToken(token);
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
</div>
|
||||
|
||||
</Layout.Content>
|
||||
|
||||
@@ -72,32 +72,49 @@ const UsersTable = () => {
|
||||
}, {
|
||||
title: '状态', dataIndex: 'status', render: (text, record, index) => {
|
||||
return (<div>
|
||||
{renderStatus(text)}
|
||||
{record.DeletedAt !== null? <Tag color='red'>已注销</Tag> : renderStatus(text)}
|
||||
</div>);
|
||||
},
|
||||
}, {
|
||||
title: '', dataIndex: 'operate', render: (text, record, index) => (<div>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'promote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
|
||||
</Popconfirm>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'demote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
|
||||
</Popconfirm>
|
||||
{
|
||||
record.DeletedAt !== null ? <></>:
|
||||
<>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'promote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
|
||||
</Popconfirm>
|
||||
<Popconfirm
|
||||
title="确定?"
|
||||
okType={'warning'}
|
||||
onConfirm={() => {
|
||||
manageUser(record.username, 'demote', record)
|
||||
}}
|
||||
>
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
|
||||
</Popconfirm>
|
||||
{record.status === 1 ?
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'disable', record)
|
||||
}}>禁用</Button> :
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'enable', record);
|
||||
}} disabled={record.status === 3}>启用</Button>}
|
||||
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
|
||||
setEditingUser(record);
|
||||
setShowEditUser(true);
|
||||
}}>编辑</Button>
|
||||
</>
|
||||
|
||||
}
|
||||
<Popconfirm
|
||||
title="确定是否要删除此用户?"
|
||||
content="此修改将不可逆"
|
||||
content="硬删除,此修改将不可逆"
|
||||
okType={'danger'}
|
||||
position={'left'}
|
||||
onConfirm={() => {
|
||||
@@ -108,17 +125,6 @@ const UsersTable = () => {
|
||||
>
|
||||
<Button theme='light' type='danger' style={{marginRight: 1}}>删除</Button>
|
||||
</Popconfirm>
|
||||
{record.status === 1 ?
|
||||
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'disable', record)
|
||||
}}>禁用</Button> :
|
||||
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
|
||||
manageUser(record.username, 'enable', record);
|
||||
}} disabled={record.status === 3}>启用</Button>}
|
||||
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
|
||||
setEditingUser(record);
|
||||
setShowEditUser(true);
|
||||
}}>编辑</Button>
|
||||
</div>),
|
||||
},];
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
export const CHANNEL_OPTIONS = [
|
||||
{ key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI' },
|
||||
{ key: 24, text: 'Midjourney Proxy', value: 24, color: 'light-blue', label: 'Midjourney Proxy' },
|
||||
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude' },
|
||||
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI' },
|
||||
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2' },
|
||||
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆' },
|
||||
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问' },
|
||||
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知' },
|
||||
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM' },
|
||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
|
||||
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
|
||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
|
||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue', label: '知识库:FastGPT' },
|
||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple', label: '知识库:AI Proxy' },
|
||||
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
|
||||
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
|
||||
{key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude'},
|
||||
{key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI'},
|
||||
{key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2'},
|
||||
{key: 24, text: 'Google Gemini', value: 24, color: 'orange', label: 'Google Gemini'},
|
||||
{key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆'},
|
||||
{key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问'},
|
||||
{key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知'},
|
||||
{key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM'},
|
||||
{key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑'},
|
||||
{key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元'},
|
||||
{key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道'},
|
||||
{key: 22, text: '知识库:FastGPT', value: 22, color: 'blue', label: '知识库:FastGPT'},
|
||||
{key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple', label: '知识库:AI Proxy'},
|
||||
];
|
||||
|
||||
@@ -86,6 +86,9 @@ const EditChannel = (props) => {
|
||||
case 23:
|
||||
localModels = ['hunyuan'];
|
||||
break;
|
||||
case 24:
|
||||
localModels = ['gemini-pro'];
|
||||
break;
|
||||
}
|
||||
setInputs((inputs) => ({...inputs, models: localModels}));
|
||||
}
|
||||
|
||||
@@ -74,12 +74,17 @@ const TopUp = () => {
|
||||
const {message, data} = res.data;
|
||||
// showInfo(message);
|
||||
if (message === 'success') {
|
||||
|
||||
let params = data
|
||||
let url = res.data.url
|
||||
let form = document.createElement('form')
|
||||
form.action = url
|
||||
form.method = 'POST'
|
||||
form.target = '_blank'
|
||||
// 判断是否为safari浏览器
|
||||
let isSafari = navigator.userAgent.indexOf("Safari") > -1 && navigator.userAgent.indexOf("Chrome") < 1;
|
||||
if (!isSafari) {
|
||||
form.target = '_blank'
|
||||
}
|
||||
for (let key in params) {
|
||||
let input = document.createElement('input')
|
||||
input.type = 'hidden'
|
||||
|
||||
Reference in New Issue
Block a user