Compare commits

...

43 Commits

Author SHA1 Message Date
CaIon
bdd611fd33 feat: 加入渠道加权随机功能 2023-12-27 19:00:47 +08:00
CaIon
1a8a24698f Happy New Year 2023-12-27 18:04:02 +08:00
CaIon
c09f3b9282 enable parseTime=true 2023-12-27 16:47:32 +08:00
CaIon
e8bcf60f0a 修复不能删除注销用户的问题 2023-12-27 16:46:18 +08:00
CaIon
14592f9758 support gemini-pro-vision 2023-12-27 16:32:54 +08:00
CaIon
4036355fae bump golang.org/x/crypto from 0.14.0 to 0.17.0 2023-12-27 16:32:42 +08:00
CaIon
fa2efb7357 bump golang.org/x/crypto from 0.14.0 to 0.17.0 2023-12-27 15:39:29 +08:00
CaIon
7a8344c40a 修复用户管理显示问题 2023-12-27 15:35:16 +08:00
CaIon
6ad2544415 enable mysql parseTime 2023-12-27 15:33:35 +08:00
CaIon
7cd1261a81 Revert "fix delete user"
This reverts commit 07fe5a02
2023-12-27 15:31:28 +08:00
CaIon
45ed973f1c fix delete user 2023-12-27 15:15:07 +08:00
CaIon
07fe5a02af fix delete user 2023-12-27 14:36:19 +08:00
CaIon
7a4969c238 update Dockerfile 2023-12-27 00:19:25 +08:00
CaIon
ad6842da7f 用户注销后禁止再次注册 2023-12-27 00:19:11 +08:00
CaIon
3c52a0991f 恢复渠道优先级可设置为负数 2023-12-26 23:49:57 +08:00
Xyfacai
fd4ef086dc fix: 优化 mj 获取进度 2023-12-23 23:14:58 +08:00
CaIon
7c4719b6ee update README.md 2023-12-21 23:12:31 +08:00
CaIon
b9d040cf52 fix VERSION 2023-12-21 23:11:52 +08:00
CaIon
6e8ff8c057 fix gemini 2023-12-21 23:08:09 +08:00
CaIon
676dc95793 update README.md 2023-12-21 20:21:15 +08:00
CaIon
1c06bddafe 优化gpt-4-gizmo-*日志 2023-12-21 20:20:09 +08:00
CaIon
3475643257 支持设置模型按次计费 2023-12-21 20:14:04 +08:00
CaIon
45e1042e58 update GeneralOpenAIRequest 2023-12-20 21:20:09 +08:00
CaIon
f5a36a05e5 log输出更多有效信息 2023-12-20 20:53:36 +08:00
CaIon
5730c69385 个人设置添加修改密码功能 2023-12-20 20:53:15 +08:00
CaIon
2d33283afb fix gemini 2023-12-19 12:53:56 +08:00
CaIon
e057c0e42e 添加gemini支持 2023-12-18 23:46:05 +08:00
CaIon
b3f1da44dd 修复vision格式问题 2023-12-18 23:46:05 +08:00
CaIon
f048cefed1 update README.md 2023-12-18 23:46:05 +08:00
CaIon
0226d94ea6 修改docker-compose.yml 2023-12-18 23:46:04 +08:00
CaIon
42469cb782 修复无法指定渠道id的问题 2023-12-14 16:43:20 +08:00
CaIon
e1da1e31d5 添加批量删除渠道功能 2023-12-14 16:35:03 +08:00
CaIon
0fdd4fc6e3 美化绘画IU 2023-12-14 15:16:51 +08:00
CaIon
261dc43c4e 修复 测试所有渠道按钮 失效的问题 2023-12-14 15:16:27 +08:00
CaIon
6463e0539f Warning when SESSION_SECRET is 'random_string' 2023-12-14 11:10:14 +08:00
CaIon
c5f08a757d Merge remote-tracking branch 'public/main' into latest 2023-12-14 11:02:08 +08:00
CaIon
8a9bd08d66 update README.md 2023-12-14 11:01:53 +08:00
Calcium-Ion
751c33a6c0 Merge pull request #30 from YOMIkio/email
fix: add Date header for email
2023-12-14 10:53:39 +08:00
CaIon
57f664d0fa 修复Safari无法拉起支付的问题 2023-12-13 17:02:33 +08:00
liyujie
cdcbebce6d fix: add Date header for email 2023-12-13 16:51:19 +08:00
CaIon
a29e3765f4 修复gpts无法设置倍率的问题 2023-12-13 00:29:53 +08:00
CaIon
766e20719d update README.md 2023-12-12 16:41:51 +08:00
CaIon
4b93f185bb 添加gpt-4-1106-vision-preview模型 2023-12-12 16:41:37 +08:00
44 changed files with 1434 additions and 230 deletions

View File

@@ -1,5 +1,5 @@
# Neko API
# New API
> [!NOTE]
> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。
@@ -29,7 +29,21 @@
+ 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况方便二次分销
5. 渠道显示已使用额度,支持指定组织访问
6. 分页支持选择每页显示数量
7. 支持gpt-4-1106-vision-previewdall-e-3tts-1
7. 支持 gpt-4-1106-vision-previewdall-e-3tts-1
8. 支持第三方模型 **gps** gpt-4-gizmo-*在渠道中添加自定义模型gpt-4-gizmo-*即可
9. 兼容原版One API的数据库可直接使用原版数据库one-api.db
10. 支持模型按次数收费,可在 系统设置-运营设置 中设置
11. 支持gemini-pro模型
## 部署
### 基于 Docker 进行部署
```shell
# 使用 SQLite 的部署命令:
docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数。
# 例如:
docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/new-api:/data calciumion/new-api:latest
```
## 交流群
<img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="500">

View File

@@ -11,7 +11,7 @@ import (
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "One API"
var SystemName = "New API"
var ServerAddress = "http://localhost:3000"
var PayAddress = ""
var EpayId = ""
@@ -190,6 +190,7 @@ const (
ChannelTypeAIProxyLibrary = 21
ChannelTypeFastGPT = 22
ChannelTypeTencent = 23
ChannelTypeGemini = 24
)
var ChannelBaseURLs = []string{

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/smtp"
"strings"
"time"
)
func SendEmail(subject string, receiver string, content string) error {
@@ -16,8 +17,9 @@ func SendEmail(subject string, receiver string, content string) error {
mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+
"Subject: %s\r\n"+
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, SystemName, SMTPFrom, encodedSubject, content))
receiver, SystemName, SMTPFrom, encodedSubject, time.Now().Format(time.RFC1123Z), content))
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";")

View File

@@ -12,7 +12,7 @@ import (
"strings"
)
func DecodeBase64ImageData(base64String string) (image.Config, error) {
func DecodeBase64ImageData(base64String string) (image.Config, string, error) {
// 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 {
base64String = base64String[idx+1:]
@@ -22,20 +22,51 @@ func DecodeBase64ImageData(base64String string) (image.Config, error) {
decodedData, err := base64.StdEncoding.DecodeString(base64String)
if err != nil {
fmt.Println("Error: Failed to decode base64 string")
return image.Config{}, err
return image.Config{}, "", err
}
// 创建一个bytes.Buffer用于存储解码后的数据
reader := bytes.NewReader(decodedData)
config, err := getImageConfig(reader)
return config, err
config, format, err := getImageConfig(reader)
return config, format, err
}
func DecodeUrlImageData(imageUrl string) (image.Config, error) {
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
func DecodeUrlImageData(imageUrl string) (image.Config, string, error) {
response, err := http.Get(imageUrl)
if err != nil {
SysLog(fmt.Sprintf("fail to get image from url: %s", err.Error()))
return image.Config{}, err
return image.Config{}, "", err
}
// 限制读取的字节数,防止下载整个图片
@@ -45,14 +76,14 @@ func DecodeUrlImageData(imageUrl string) (image.Config, error) {
// log.Fatal(err)
//}
//log.Printf("%x", data)
config, err := getImageConfig(limitReader)
config, format, err := getImageConfig(limitReader)
response.Body.Close()
return config, err
return config, format, err
}
func getImageConfig(reader io.Reader) (image.Config, error) {
func getImageConfig(reader io.Reader) (image.Config, string, error) {
// 读取图片的头部信息来获取图片尺寸
config, _, err := image.DecodeConfig(reader)
config, format, err := image.DecodeConfig(reader)
if err != nil {
err = errors.New(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error()))
SysLog(err.Error())
@@ -61,9 +92,10 @@ func getImageConfig(reader io.Reader) (image.Config, error) {
err = errors.New(fmt.Sprintf("fail to decode image config(webp): %s", err.Error()))
SysLog(err.Error())
}
format = "webp"
}
if err != nil {
return image.Config{}, err
return image.Config{}, "", err
}
return config, nil
return config, format, nil
}

View File

@@ -16,7 +16,7 @@ var (
)
func printHelp() {
fmt.Println("One API " + Version + " - All in one API service for OpenAI API.")
fmt.Println("New API " + Version + " - All in one API service for OpenAI API.")
fmt.Println("Copyright (C) 2023 JustSong. All rights reserved.")
fmt.Println("GitHub: https://github.com/songquanpeng/one-api")
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
@@ -36,7 +36,14 @@ func init() {
}
if os.Getenv("SESSION_SECRET") != "" {
SessionSecret = os.Getenv("SESSION_SECRET")
ss := os.Getenv("SESSION_SECRET")
if ss == "random_string" {
log.Println("WARNING: SESSION_SECRET is set to the default value 'random_string', please change it to a random string.")
log.Println("警告SESSION_SECRET被设置为默认值'random_string',请修改为随机字符串。")
log.Fatal("Please set SESSION_SECRET to a random string.")
} else {
SessionSecret = ss
}
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")

View File

@@ -15,6 +15,7 @@ import (
// 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{
"midjourney": 50,
"gpt-4-gizmo-*": 15,
"gpt-4": 15,
"gpt-4-0314": 15,
"gpt-4-0613": 15,
@@ -23,6 +24,7 @@ var ModelRatio = map[string]float64{
"gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75,
@@ -59,6 +61,8 @@ var ModelRatio = map[string]float64{
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
@@ -74,6 +78,35 @@ var ModelRatio = map[string]float64{
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
}
var ModelPrice = map[string]float64{
"gpt-4-gizmo-*": 0.1,
}
func ModelPrice2JSONString() string {
jsonBytes, err := json.Marshal(ModelPrice)
if err != nil {
SysError("error marshalling model price: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelPriceByJSONString(jsonStr string) error {
ModelPrice = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &ModelPrice)
}
func GetModelPrice(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
price, ok := ModelPrice[name]
if !ok {
//SysError("model price not found: " + name)
return -1
}
return price
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
@@ -88,6 +121,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
}
func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*"
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)

View File

@@ -168,6 +168,11 @@ func GetRandomString(length int) string {
return string(key)
}
func GetRandomInt(max int) int {
//rand.Seed(time.Now().UnixNano())
return rand.Intn(max)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}

View File

@@ -29,6 +29,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
fallthrough
case common.ChannelType360:
fallthrough
case common.ChannelTypeGemini:
fallthrough
case common.ChannelTypeXunfei:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:

View File

@@ -151,6 +151,36 @@ func DeleteDisabledChannel(c *gin.Context) {
return
}
type ChannelBatch struct {
Ids []int `json:"ids"`
}
func DeleteChannelBatch(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchDeleteChannels(channelBatch.Ids)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}
func UpdateChannel(c *gin.Context) {
channel := model.Channel{}
err := c.ShouldBindJSON(&channel)

View File

@@ -16,8 +16,10 @@ import (
"time"
)
func UpdateMidjourneyTask() {
/*func UpdateMidjourneyTask() {
//revocer
//imageModel := "midjourney"
ctx := context.TODO()
imageModel := "midjourney"
defer func() {
if err := recover(); err != nil {
@@ -28,27 +30,28 @@ func UpdateMidjourneyTask() {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
log.Printf("检测到未完成的任务数有: %v", len(tasks))
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
for _, task := range tasks {
log.Printf("未完成的任务信息: %v", task)
common.LogInfo(ctx, fmt.Sprintf("未完成的任务信息: %v", task))
midjourneyChannel, err := model.GetChannelById(task.ChannelId, true)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
common.LogError(ctx, fmt.Sprintf("UpdateMidjourneyTask: %v", err))
task.FailReason = fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", task.ChannelId)
task.Status = "FAILURE"
task.Progress = "100%"
err := task.Update()
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
continue
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/%s/fetch", *midjourneyChannel.BaseURL, task.MjId)
log.Printf("requestUrl: %s", requestUrl)
common.LogInfo(ctx, fmt.Sprintf("requestUrl: %s", requestUrl))
req, err := http.NewRequest("GET", requestUrl, bytes.NewBuffer([]byte("")))
if err != nil {
log.Printf("UpdateMidjourneyTask error: %v", err)
common.LogInfo(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
@@ -111,7 +114,7 @@ func UpdateMidjourneyTask() {
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
log.Println(task.MjId + " 构建失败," + task.FailReason)
common.LogWarn(task.MjId + " 构建失败," + task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
@@ -126,8 +129,8 @@ func UpdateMidjourneyTask() {
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, 1, logContent)
logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
@@ -142,6 +145,180 @@ func UpdateMidjourneyTask() {
}
}
}
*/
func UpdateMidjourneyTaskBulk() {
//revocer
defer func() {
if err := recover(); err != nil {
log.Printf("UpdateMidjourneyTask panic: %v", err)
}
}()
imageModel := "midjourney"
ctx := context.TODO()
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) == 0 {
continue
}
common.LogInfo(ctx, fmt.Sprintf("检测到未完成的任务数有: %v", len(tasks)))
taskChannelM := make(map[int][]string)
taskM := make(map[string]*model.Midjourney)
for _, task := range tasks {
if task.MjId == "" {
continue
}
taskM[task.MjId] = task
taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.MjId)
}
if len(taskChannelM) == 0 {
continue
}
for channelId, taskIds := range taskChannelM {
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
if len(taskIds) == 0 {
continue
}
midjourneyChannel, err := model.CacheGetChannel(channelId)
if err != nil {
common.LogError(ctx, fmt.Sprintf("CacheGetChannel: %v", err))
err := model.MjBulkUpdate(taskIds, map[string]any{
"fail_reason": fmt.Sprintf("获取渠道信息失败请联系管理员渠道ID%d", channelId),
"status": "FAILURE",
"progress": "100%",
})
if err != nil {
common.LogInfo(ctx, fmt.Sprintf("UpdateMidjourneyTask error: %v", err))
}
continue
}
requestUrl := fmt.Sprintf("%s/mj/task/list-by-condition", *midjourneyChannel.BaseURL)
body, _ := json.Marshal(map[string]any{
"ids": taskIds,
})
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task error: %v", err))
continue
}
// 设置超时时间
timeout := time.Second * 5
ctx, cancel := context.WithTimeout(context.Background(), timeout)
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", midjourneyChannel.Key)
resp, err := httpClient.Do(req)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task Do req error: %v", err))
continue
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error: %v", err))
continue
}
var responseItems []Midjourney
err = json.Unmarshal(responseBody, &responseItems)
if err != nil {
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v", err))
continue
}
resp.Body.Close()
req.Body.Close()
cancel()
for _, responseItem := range responseItems {
task := taskM[responseItem.MjId]
if !checkMjTaskNeedUpdate(task, responseItem) {
continue
}
task.Code = 1
task.Progress = responseItem.Progress
task.PromptEn = responseItem.PromptEn
task.State = responseItem.State
task.SubmitTime = responseItem.SubmitTime
task.StartTime = responseItem.StartTime
task.FinishTime = responseItem.FinishTime
task.ImageUrl = responseItem.ImageUrl
task.Status = responseItem.Status
task.FailReason = responseItem.FailReason
if task.Progress != "100%" && responseItem.FailReason != "" {
common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason)
task.Progress = "100%"
err = model.CacheUpdateUserQuota(task.UserId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err = model.IncreaseUserQuota(task.UserId, quota)
if err != nil {
common.LogError(ctx, "fail to increase user quota: "+err.Error())
}
logContent := fmt.Sprintf("构图失败 %s补偿 %s", task.MjId, common.LogQuota(quota))
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
}
}
}
err = task.Update()
if err != nil {
common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error())
}
}
}
}
}
func checkMjTaskNeedUpdate(oldTask *model.Midjourney, newTask Midjourney) bool {
if oldTask.Code != 1 {
return true
}
if oldTask.Progress != newTask.Progress {
return true
}
if oldTask.PromptEn != newTask.PromptEn {
return true
}
if oldTask.State != newTask.State {
return true
}
if oldTask.SubmitTime != newTask.SubmitTime {
return true
}
if oldTask.StartTime != newTask.StartTime {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.ImageUrl != newTask.ImageUrl {
return true
}
if oldTask.Status != newTask.Status {
return true
}
if oldTask.FailReason != newTask.FailReason {
return true
}
if oldTask.FinishTime != newTask.FinishTime {
return true
}
if oldTask.Progress != "100%" && newTask.FailReason != "" {
return true
}
return false
}
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
@@ -187,6 +364,12 @@ func GetUserMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if !strings.Contains(common.ServerAddress, "localhost") {
for i, midjourney := range logs {
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",

View File

@@ -261,6 +261,15 @@ func init() {
Root: "gpt-4-vision-preview",
Parent: nil,
},
{
Id: "gpt-4-1106-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-vision-preview",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
@@ -414,6 +423,24 @@ func init() {
Root: "PaLM-2",
Parent: nil,
},
{
Id: "gemini-pro",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",

View File

@@ -18,7 +18,7 @@ type ClaudeMetadata struct {
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
MaxTokensToSample uint `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`

336
controller/relay-gemini.go Normal file
View File

@@ -0,0 +1,336 @@
package controller
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
const (
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
//SafetySettings: []GeminiChatSafetySettings{
// {
// Category: "HARM_CATEGORY_HARASSMENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_HATE_SPEECH",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
// {
// Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
// Threshold: "BLOCK_ONLY_HIGH",
// },
//},
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := GeminiChatContent{
Role: message.Role,
Parts: []GeminiPart{
{
Text: string(message.Content),
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := common.GetImageFromUrl(part.ImageUrl.(MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
choice := OpenAITextResponseChoice{
Index: i,
Message: Message{
Role: "assistant",
Content: content,
},
FinishReason: stopFinishReason,
}
content, _ = json.Marshal(candidate.Content.Parts[0].Text)
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = content
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &stopFinishReason
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
setEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse GeminiChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -31,7 +31,7 @@ type PaLMChatRequest struct {
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
TopK uint `json:"topK,omitempty"`
}
type PaLMError struct {

View File

@@ -26,6 +26,7 @@ const (
APITypeXunfei
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
)
var httpClient *http.Client
@@ -119,6 +120,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeAIProxyLibrary
case common.ChannelTypeTencent:
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -180,6 +183,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
requestBaseURL = baseURL
}
version := "v1beta"
if c.GetString("api_version") != "" {
version = c.GetString("api_version")
}
action := "generateContent"
if textRequest.Stream {
action = "streamGenerateContent"
}
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
//log.Println(fullRequestURL)
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
@@ -209,14 +231,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
modelPrice := common.GetModelPrice(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
var preConsumedQuota int
var ratio float64
var modelRatio float64
if modelPrice == -1 {
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + int(textRequest.MaxTokens)
}
modelRatio = common.GetModelRatio(textRequest.Model)
ratio = modelRatio * groupRatio
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
} else {
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -280,6 +312,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeGemini:
geminiChatRequest := requestOpenAI2Gemini(textRequest)
jsonStr, err := json.Marshal(geminiChatRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
@@ -375,13 +414,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
case APITypeTencent:
req.Header.Set("Authorization", apiKey)
case APITypeGemini:
req.Header.Set("Content-Type", "application/json")
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
if apiType != APITypeGemini {
// 设置公共头部...
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
//req.HeaderBar.Set("Connection", c.Request.HeaderBar.Get("Connection"))
resp, err = httpClient.Do(req)
@@ -418,15 +462,19 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) {
// c.Writer.Flush()
go func() {
quota := 0
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
quota := 0
if modelPrice == -1 {
completionRatio := common.GetCompletionRatio(textRequest.Model)
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
} else {
quota = int(modelPrice * common.QuotaPerUnit * groupRatio)
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
@@ -443,10 +491,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
// record all the consume log even if quota is 0
useTimeSeconds := time.Now().Unix() - startTime.Unix()
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId, userQuota)
var logContent string
if modelPrice == -1 {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,用时 %d秒", modelRatio, groupRatio, useTimeSeconds)
} else {
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f,用时 %d秒", modelPrice, groupRatio, useTimeSeconds)
}
logModel := textRequest.Model
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf(",模型 %s", textRequest.Model)
}
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, tokenId, userQuota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
//if quota != 0 {
@@ -539,6 +599,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
return nil
}
case APITypeGemini:
if textRequest.Stream {
err, responseText := geminiChatStreamHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage.PromptTokens = promptTokens
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil
} else {
err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)

View File

@@ -76,12 +76,13 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
}
var config image.Config
var err error
var format string
if strings.HasPrefix(imageUrl.Url, "http") {
common.SysLog(fmt.Sprintf("downloading image: %s", imageUrl.Url))
config, err = common.DecodeUrlImageData(imageUrl.Url)
config, format, err = common.DecodeUrlImageData(imageUrl.Url)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, err = common.DecodeBase64ImageData(imageUrl.Url)
config, format, err = common.DecodeBase64ImageData(imageUrl.Url)
}
if err != nil {
return 0, err
@@ -101,7 +102,7 @@ func getImageToken(imageUrl *MessageImageUrl) (int, error) {
shortSide := config.Width
otherSide := config.Height
log.Printf("width: %d, height: %d", config.Width, config.Height)
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
// 缩放倍数
scale := 1.0
if config.Height < shortSide {
@@ -160,10 +161,27 @@ func countTokenMessages(messages []Message, model string) (int, error) {
} else {
for _, m := range arrayContent {
if m.Type == "image_url" {
imageTokenNum, err := getImageToken(&m.ImageUrl)
var imageTokenNum int
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else {
@@ -177,12 +195,12 @@ func countTokenMessages(messages []Message, model string) (int, error) {
}
func countTokenInput(input any, model string) int {
switch input.(type) {
switch v := input.(type) {
case string:
return countTokenText(input.(string), model)
return countTokenText(v, model)
case []string:
text := ""
for _, s := range input.([]string) {
for _, s := range v {
text += s
}
return countTokenText(text, model)

View File

@@ -33,7 +33,7 @@ type XunfeiChatRequest struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`

View File

@@ -19,9 +19,9 @@ type Message struct {
}
type MediaMessage struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl MessageImageUrl `json:"image_url,omitempty"`
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
}
type MessageImageUrl struct {
@@ -29,6 +29,60 @@ type MessageImageUrl struct {
Detail string `json:"detail"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: stringContent,
})
return contentList
}
var arrayContent []json.RawMessage
if err := json.Unmarshal(m.Content, &arrayContent); err == nil {
for _, contentItem := range arrayContent {
var contentMap map[string]any
if err := json.Unmarshal(contentItem, &contentMap); err != nil {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
detail, ok := subObj["detail"]
if ok {
subObj["detail"] = detail.(string)
} else {
subObj["detail"] = "auto"
}
contentList = append(contentList, MediaMessage{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
Detail: subObj["detail"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
@@ -53,7 +107,7 @@ type GeneralOpenAIRequest struct {
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
@@ -91,14 +145,14 @@ type AudioRequest struct {
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
MaxTokens uint `json:"max_tokens"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
MaxTokens uint `json:"max_tokens"`
//Stream bool `json:"stream"`
}

View File

@@ -152,6 +152,21 @@ func Register(c *gin.Context) {
return
}
}
exist, err := model.CheckUserExistOrDeleted(user.Username, user.Email)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if exist {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户名已存在,或已注销",
})
return
}
affCode := user.AffCode // this code is the inviter's code, not the user's own code
inviterId, _ := model.GetUserIdByAffCode(affCode)
cleanUser := model.User{
@@ -525,7 +540,7 @@ func DeleteUser(c *gin.Context) {
})
return
}
err = model.DeleteUserById(id)
err = model.HardDeleteUserById(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -632,7 +647,7 @@ func ManageUser(c *gin.Context) {
Username: req.Username,
}
// Fill attributes
model.DB.Where(&user).First(&user)
model.DB.Unscoped().Where(&user).First(&user)
if user.Id == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -668,7 +683,7 @@ func ManageUser(c *gin.Context) {
})
return
}
if err := user.Delete(); err != nil {
if err := user.HardDelete(); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),

View File

@@ -1,9 +1,9 @@
version: '3.4'
services:
one-api:
new-api:
image: calciumion/new-api:latest
container_name: one-api
container_name: new-api
restart: always
command: --log-dir /app/logs
ports:
@@ -12,7 +12,7 @@ services:
- ./data:/data
- ./logs:/app/logs
environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/new-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai

12
go.mod
View File

@@ -19,7 +19,7 @@ require (
github.com/samber/lo v1.38.1
github.com/shirou/gopsutil v3.21.11+incompatible
github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2
golang.org/x/crypto v0.14.0
golang.org/x/crypto v0.17.0
gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
@@ -44,13 +44,14 @@ require (
github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jackc/pgx/v5 v5.5.1 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -64,8 +65,9 @@ require (
golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

14
go.sum
View File

@@ -81,6 +81,10 @@ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
@@ -108,6 +112,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
@@ -172,11 +178,15 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM=
golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -189,12 +199,16 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=

View File

@@ -27,7 +27,7 @@ var indexPage []byte
func main() {
common.SetupLogger()
common.SysLog("One API " + common.Version + " started")
common.SysLog("New API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode)
}
@@ -81,7 +81,7 @@ func main() {
}
go controller.AutomaticallyTestChannels(frequency)
}
go controller.UpdateMidjourneyTask()
go controller.UpdateMidjourneyTaskBulk()
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
common.BatchUpdateEnabled = true
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

View File

@@ -91,11 +91,11 @@ func TokenAuth() func(c *gin.Context) {
key = c.Request.Header.Get("mj-api-secret")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
parts = strings.Split(key, "-")
key = parts[0]
} else {
key = strings.TrimPrefix(key, "sk-")
parts := strings.Split(key, "-")
parts = strings.Split(key, "-")
key = parts[0]
}
token, err := model.ValidateUserToken(key)

View File

@@ -107,6 +107,8 @@ func Distribute() func(c *gin.Context) {
c.Set("api_version", channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set("library_id", channel.Other)
case common.ChannelTypeGemini:
c.Set("api_version", channel.Other)
}
c.Next()
}

View File

@@ -6,11 +6,12 @@ import (
)
type Ability struct {
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
Group string `json:"group" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
Weight uint `json:"weight" gorm:"default:0;index"`
}
func GetGroupModels(group string) []string {
@@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
var abilities []Ability
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
@@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error
err = channelQuery.Order("weight DESC").Find(&abilities).Error
}
if err != nil {
return nil, err
}
channel := Channel{}
channel.Id = ability.ChannelId
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
if len(abilities) > 0 {
// Randomly choose one
weightSum := uint(0)
for _, ability_ := range abilities {
weightSum += ability_.Weight
}
if weightSum == 0 {
// All weight is 0, randomly choose one
channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
} else {
// Randomly choose one
weight := common.GetRandomInt(int(weightSum))
for _, ability_ := range abilities {
weight -= int(ability_.Weight)
//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
if weight <= 0 {
channel.Id = ability_.ChannelId
break
}
}
}
} else {
return nil, nil
}
err = DB.First(&channel, "id = ?", channel.Id).Error
return &channel, err
}
@@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
}
abilities = append(abilities, ability)
}

View File

@@ -133,6 +133,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
}
var group2model2channels map[string]map[string][]*Channel
var channelsIDM map[int]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
@@ -149,10 +150,12 @@ func InitChannelCache() {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
newChannelsIDM := make(map[int]*Channel)
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
}
for _, channel := range channels {
newChannelsIDM[channel.Id] = channel
groups := strings.Split(channel.Group, ",")
for _, group := range groups {
models := strings.Split(channel.Models, ",")
@@ -177,6 +180,7 @@ func InitChannelCache() {
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
channelsIDM = newChannelsIDM
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
@@ -194,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
model = "gpt-4-gizmo-*"
}
// if memory cache is disabled, get channel directly from database
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
}
@@ -214,6 +219,41 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
// Calculate the total weight of all channels up to endIdx
totalWeight := 0
for _, channel := range channels[:endIdx] {
totalWeight += channel.GetWeight()
}
if totalWeight == 0 {
// If all weights are 0, select a channel randomly
return channels[rand.Intn(endIdx)], nil
}
// Generate a random value in the range [0, totalWeight)
randomWeight := rand.Intn(totalWeight)
// Find a channel based on its weight
for _, channel := range channels[:endIdx] {
randomWeight -= channel.GetWeight()
if randomWeight <= 0 {
return channel, nil
}
}
// return the last channel if no channel is found
return channels[endIdx-1], nil
}
func CacheGetChannel(id int) (*Channel, error) {
if !common.MemoryCacheEnabled {
return GetChannelById(id, true)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
c, ok := channelsIDM[id]
if !ok {
return nil, errors.New(fmt.Sprintf("当前渠道# %d已不存在", id))
}
return c, nil
}

View File

@@ -21,7 +21,7 @@ type Channel struct {
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
Group string `json:"group" gorm:"type:varchar(255);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
@@ -86,6 +86,26 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
func BatchDeleteChannels(ids []int) error {
//使用事务 删除channel表和channel_ability表
tx := DB.Begin()
err := tx.Where("id in (?)", ids).Delete(&Channel{}).Error
if err != nil {
// 回滚事务
tx.Rollback()
return err
}
err = tx.Where("channel_id in (?)", ids).Delete(&Ability{}).Error
if err != nil {
// 回滚事务
tx.Rollback()
return err
}
// 提交事务
tx.Commit()
return err
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
@@ -93,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
return *channel.Priority
}
func (channel *Channel) GetWeight() int {
if channel.Weight == nil {
return 0
}
return int(*channel.Weight)
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""

View File

@@ -52,6 +52,14 @@ func chooseDB() (*gorm.DB, error) {
}
// Use MySQL
common.SysLog("using MySQL as database")
// check parseTime
if !strings.Contains(dsn, "parseTime") {
if strings.Contains(dsn, "?") {
dsn += "&parseTime=true"
} else {
dsn += "?parseTime=true"
}
}
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})

View File

@@ -131,3 +131,9 @@ func (midjourney *Midjourney) Update() error {
err = DB.Save(midjourney).Error
return err
}
func MjBulkUpdate(taskIDs []string, params map[string]any) error {
return DB.Model(&Midjourney{}).
Where("mj_id in (?)", taskIDs).
Updates(params).Error
}

View File

@@ -70,6 +70,7 @@ func InitOptionMap() {
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@@ -220,6 +221,8 @@ func updateOptionMap(key string, value string) (err error) {
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "ModelPrice":
err = common.UpdateModelPriceByJSONString(value)
case "TopUpLink":
common.TopUpLink = value
case "ChatLink":

View File

@@ -45,7 +45,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
token, err = CacheGetTokenByKey(key)
if err == nil {
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
return nil, errors.New("该令牌额度已用尽 token.Status == common.TokenStatusExhausted " + key)
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
@@ -71,7 +71,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌额度已用尽")
return nil, errors.New(fmt.Sprintf("%s 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", token.Key, token.RemainQuota))
}
return token, nil
}

View File

@@ -11,26 +11,51 @@ import (
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
Id int `json:"id"`
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
Id int `json:"id"`
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
AffCount int `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
AffQuota int `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"` // 邀请剩余额度
AffHistoryQuota int `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // 邀请历史额度
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
var user User
// err := DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
// check email if empty
var err error
if email == "" {
err = DB.Unscoped().First(&user, "username = ?", username).Error
} else {
err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
}
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// not exist, return false, nil
return false, nil
}
// other error, return false, err
return false, err
}
// exist, return true, nil
return true, nil
}
func GetMaxUserId() int {
@@ -40,7 +65,7 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
err = DB.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
return users, err
}
@@ -80,6 +105,14 @@ func DeleteUserById(id int) (err error) {
return user.Delete()
}
func HardDeleteUserById(id int) error {
if id == 0 {
return errors.New("id 为空!")
}
err := DB.Unscoped().Delete(&User{}, "id = ?", id).Error
return err
}
func inviteUser(inviterId int) (err error) {
user, err := GetUserById(inviterId, true)
if err != nil {
@@ -183,6 +216,14 @@ func (user *User) Delete() error {
return err
}
func (user *User) HardDelete() error {
if user.Id == 0 {
return errors.New("id 为空!")
}
err := DB.Unscoped().Delete(user).Error
return err
}
// ValidateAndFill check password & user status
func (user *User) ValidateAndFill() (err error) {
// When querying with struct, GORM will only query with non-zero fields,

View File

@@ -83,6 +83,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel)
channelRoute.POST("/batch", controller.DeleteChannelBatch)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())

View File

@@ -10,6 +10,7 @@
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.2.3",
"react-fireworks": "^1.0.4",
"react-router-dom": "^6.3.0",
"react-scripts": "5.0.1",
"react-toastify": "^9.0.8",

View File

@@ -74,6 +74,11 @@ function renderBalance(type, balance) {
const ChannelsTable = () => {
const columns = [
// {
// title: '',
// dataIndex: 'checkbox',
// className: 'checkbox',
// },
{
title: 'ID',
dataIndex: 'id',
@@ -158,11 +163,30 @@ const ChannelsTable = () => {
<div>
<InputNumber
style={{width: 70}}
name='name'
name='priority'
onChange={value => {
manageChannel(record.id, 'priority', record, value);
}}
defaultValue={record.priority}
min={-999}
/>
</div>
);
},
},
{
title: '权重',
dataIndex: 'weight',
render: (text, record, index) => {
return (
<div>
<InputNumber
style={{width: 70}}
name='weight'
onChange={value => {
manageChannel(record.id, 'weight', record, value);
}}
defaultValue={record.weight}
min={0}
/>
</div>
@@ -235,9 +259,11 @@ const ChannelsTable = () => {
const [channelCount, setChannelCount] = useState(pageSize);
const [groupOptions, setGroupOptions] = useState([]);
const [showEdit, setShowEdit] = useState(false);
const [enableBatchDelete, setEnableBatchDelete] = useState(false);
const [editingChannel, setEditingChannel] = useState({
id: undefined,
});
const [selectedChannels, setSelectedChannels] = useState([]);
const removeRecord = id => {
let newDataSource = [...channels];
@@ -484,6 +510,27 @@ const ChannelsTable = () => {
setUpdatingBalance(false);
};
const batchDeleteChannels = async () => {
if (selectedChannels.length === 0) {
showError('请先选择要删除的通道!');
return;
}
setLoading(true);
let ids = [];
selectedChannels.forEach((channel) => {
ids.push(channel.id);
});
const res = await API.post(`/api/channel/batch`, {ids: ids});
const {success, message, data} = res.data;
if (success) {
showSuccess(`已删除 ${data} 个通道!`);
await refresh();
} else {
showError(message);
}
setLoading(false);
}
const sortChannel = (key) => {
if (channels.length === 0) return;
setLoading(true);
@@ -557,6 +604,7 @@ const ChannelsTable = () => {
}
};
return (
<>
<EditChannel refresh={refresh} visible={showEdit} handleClose={closeEdit} editingChannel={editingChannel}/>
@@ -583,16 +631,18 @@ const ChannelsTable = () => {
</Form>
<div style={{marginTop: 10, display: 'flex'}}>
<Space>
<Typography.Text strong>使用ID排序</Typography.Text>
<Switch checked={idSort} label='使用ID排序' uncheckedText="关" aria-label="是否用ID排序" onChange={(v) => {
localStorage.setItem('id-sort', v + '')
setIdSort(v)
loadChannels(0, pageSize, v)
.then()
.catch((reason) => {
showError(reason);
})
}}></Switch>
<Space>
<Typography.Text strong>使用ID排序</Typography.Text>
<Switch checked={idSort} label='使用ID排序' uncheckedText="关" aria-label="是否用ID排序" onChange={(v) => {
localStorage.setItem('id-sort', v + '')
setIdSort(v)
loadChannels(0, pageSize, v)
.then()
.catch((reason) => {
showError(reason);
})
}}></Switch>
</Space>
</Space>
</div>
@@ -607,7 +657,15 @@ const ChannelsTable = () => {
handlePageSizeChange(size).then()
},
onPageChange: handlePageChange,
}} loading={loading} onRow={handleRow}/>
}} loading={loading} onRow={handleRow} rowSelection={
enableBatchDelete ?
{
onChange: (selectedRowKeys, selectedRows) => {
// console.log(`selectedRowKeys: ${selectedRowKeys}`, 'selectedRows: ', selectedRows);
setSelectedChannels(selectedRows);
},
} : null
}/>
<div style={{display: isMobile()?'':'flex', marginTop: isMobile()?0:-45, zIndex: 999, position: 'relative', pointerEvents: 'none'}}>
<Space style={{pointerEvents: 'auto'}}>
<Button theme='light' type='primary' style={{marginRight: 8}} onClick={
@@ -622,7 +680,7 @@ const ChannelsTable = () => {
title="确定?"
okType={'warning'}
onConfirm={testAllChannels}
position={isMobile()?'top':''}
position={isMobile()?'top':'top'}
>
<Button theme='light' type='warning' style={{marginRight: 8}}>测试所有已启用通道</Button>
</Popconfirm>
@@ -648,6 +706,24 @@ const ChannelsTable = () => {
{/*</div>*/}
</div>
<div style={{marginTop: 20}}>
<Space>
<Typography.Text strong>开启批量删除</Typography.Text>
<Switch label='开启批量删除' uncheckedText="关" aria-label="是否开启批量删除" onChange={(v) => {
setEnableBatchDelete(v)
}}></Switch>
<Popconfirm
title="确定是否要删除所选通道?"
content="此修改将不可逆"
okType={'danger'}
onConfirm={batchDeleteChannels}
disabled={!enableBatchDelete}
position={'top'}
>
<Button disabled={!enableBatchDelete} theme='light' type='danger' style={{marginRight: 8}}>删除所选通道</Button>
</Popconfirm>
</Space>
</div>
</>
);
};

View File

@@ -1,4 +1,4 @@
import React, {useContext, useEffect, useState} from 'react';
import React, {useContext, useEffect, useRef, useState} from 'react';
import {Link, useNavigate} from 'react-router-dom';
import {UserContext} from '../context/User';
@@ -6,18 +6,12 @@ import {Button, Container, Icon, Menu, Segment} from 'semantic-ui-react';
import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../helpers';
import '../index.css';
import fireworks from 'react-fireworks';
import {
IconAt,
IconHistogram,
IconGift,
IconKey,
IconUser,
IconLayers,
IconHelpCircle,
IconCreditCard,
IconSemiLogo,
IconHome,
IconImage
IconHelpCircle
} from '@douyinfe/semi-icons';
import {Nav, Avatar, Dropdown, Layout, Switch} from '@douyinfe/semi-ui';
import {stringToColor} from "../helpers/render";
@@ -49,6 +43,8 @@ const HeaderBar = () => {
const systemName = getSystemName();
const logo = getLogo();
var themeMode = localStorage.getItem('theme-mode');
const currentDate = new Date();
const isNewYear = currentDate.getMonth() === 0 && currentDate.getDate() === 1;
async function logout() {
setShowSidebar(false);
@@ -59,10 +55,24 @@ const HeaderBar = () => {
navigate('/login');
}
const handleNewYearClick = () => {
fireworks.init("root",{});
fireworks.start();
setTimeout(() => {
fireworks.stop();
setTimeout(() => {
window.location.reload();
}, 10000);
}, 3000);
};
useEffect(() => {
if (themeMode === 'dark') {
switchMode(true);
}
if (isNewYear) {
console.log('Happy New Year!');
}
}, []);
const switchMode = (model) => {
@@ -105,6 +115,19 @@ const HeaderBar = () => {
}}
footer={
<>
{isNewYear &&
// happy new year
<Dropdown
position="bottomRight"
render={
<Dropdown.Menu>
<Dropdown.Item onClick={handleNewYearClick}>Happy New Year!!!</Dropdown.Item>
</Dropdown.Menu>
}
>
<Nav.Item itemKey={'new-year'} text={'🏮'}/>
</Dropdown>
}
<Nav.Item itemKey={'about'} icon={<IconHelpCircle />} />
<Switch checkedText="🌞" size={'large'} checked={dark} uncheckedText="🌙" onChange={switchMode} />
{userState.user ?

View File

@@ -2,7 +2,19 @@ import React, {useEffect, useState} from 'react';
import {Label} from 'semantic-ui-react';
import {API, copy, isAdmin, showError, showSuccess, timestamp2string} from '../helpers';
import {Table, Avatar, Tag, Form, Button, Layout, Select, Popover, Modal } from '@douyinfe/semi-ui';
import {
Table,
Avatar,
Tag,
Form,
Button,
Layout,
Select,
Popover,
Modal,
ImagePreview,
Typography
} from '@douyinfe/semi-ui';
import {ITEMS_PER_PAGE} from '../constants';
import {renderNumber, renderQuota, stringToColor} from '../helpers/render';
@@ -194,19 +206,16 @@ const LogsTable = () => {
}
return (
text.length > 10 ?
<>
{text.slice(0, 10)}
<Button
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
查看全部
</Button>
</>
: text
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
@@ -220,19 +229,16 @@ const LogsTable = () => {
}
return (
text.length > 10 ?
<>
{text.slice(0, 10)}
<Button
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
查看全部
</Button>
</>
: text
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
},
@@ -246,19 +252,16 @@ const LogsTable = () => {
}
return (
text.length > 10 ?
<>
{text.slice(0, 10)}
<Button
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
查看全部
</Button>
</>
: text
<Typography.Text
ellipsis={{ showTooltip: true }}
style={{ width: 100 }}
onClick={() => {
setModalContent(text);
setIsModalOpen(true);
}}
>
{text}
</Typography.Text>
);
}
}
@@ -414,15 +417,11 @@ const LogsTable = () => {
>
<p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
</Modal>
{/* 模态框组件,用于展示图片 */}
<Modal
title="图片预览"
visible={isModalOpenurl}
onCancel={() => setIsModalOpenurl(false)}
footer={null} // 模态框不显示底部按钮
>
<img src={modalImageUrl} style={{ width: '100%' }} alt="结果图片" />
</Modal>
<ImagePreview
src={modalImageUrl}
visible={isModalOpenurl}
onVisibleChange={(visible) => setIsModalOpenurl(visible)}
/>
</Layout>
</>

View File

@@ -10,6 +10,7 @@ const OperationSetting = () => {
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
ModelPrice: '',
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
@@ -30,7 +31,7 @@ const OperationSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio'|| item.key === 'ModelPrice') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
@@ -97,6 +98,13 @@ const OperationSetting = () => {
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
if (originInputs['ModelPrice'] !== inputs.ModelPrice) {
if (!verifyJSON(inputs.ModelPrice)) {
showError('模型固定价格不是合法的 JSON 字符串');
return;
}
await updateOption('ModelPrice', inputs.ModelPrice);
}
break;
case 'quota':
if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
@@ -315,6 +323,17 @@ const OperationSetting = () => {
<Header as='h3'>
倍率设置
</Header>
<Form.Group widths='equal'>
<Form.TextArea
label='模型固定价格(一次调用消耗多少刀,优先级大于模型倍率)'
name='ModelPrice'
onChange={handleInputChange}
style={{minHeight: 250, fontFamily: 'JetBrains Mono, Consolas'}}
autoComplete='new-password'
value={inputs.ModelPrice}
placeholder='为一个 JSON 文本,键为模型名称,值为一次调用消耗多少刀,比如 "gpt-4-gizmo-*": 0.1一次消耗0.1刀'
/>
</Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='模型倍率'

View File

@@ -20,6 +20,7 @@ import {
import {getQuotaPerUnit, renderQuota, renderQuotaWithPrompt, stringToColor} from "../helpers/render";
import EditToken from "../pages/Token/EditToken";
import EditUser from "../pages/User/EditUser";
import passwordResetConfirm from "./PasswordResetConfirm";
const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
@@ -29,9 +30,12 @@ const PersonalSetting = () => {
wechat_verification_code: '',
email_verification_code: '',
email: '',
self_account_deletion_confirmation: ''
self_account_deletion_confirmation: '',
set_new_password: '',
set_new_password_confirmation: '',
});
const [status, setStatus] = useState({});
const [showChangePasswordModal, setShowChangePasswordModal] = useState(false);
const [showWeChatBindModal, setShowWeChatBindModal] = useState(false);
const [showEmailBindModal, setShowEmailBindModal] = useState(false);
const [showAccountDeleteModal, setShowAccountDeleteModal] = useState(false);
@@ -180,6 +184,27 @@ const PersonalSetting = () => {
}
};
const changePassword = async () => {
if (inputs.set_new_password !== inputs.set_new_password_confirmation) {
showError('两次输入的密码不一致!');
return;
}
const res = await API.put(
`/api/user/self`,
{
password: inputs.set_new_password
}
);
const {success, message} = res.data;
if (success) {
showSuccess('密码修改成功!');
setShowWeChatBindModal(false);
} else {
showError(message);
}
setShowChangePasswordModal(false);
};
const transfer = async () => {
if (transferAmount < getQuotaPerUnit()) {
showError('划转金额最低为' + renderQuota(getQuotaPerUnit()));
@@ -420,6 +445,9 @@ const PersonalSetting = () => {
<Space>
<Button onClick={generateAccessToken}>生成系统访问令牌</Button>
<Button onClick={() => {
setShowChangePasswordModal(true);
}}>修改密码</Button>
<Button type={'danger'} onClick={() => {
setShowAccountDeleteModal(true);
}}>删除个人账户</Button>
</Space>
@@ -543,6 +571,39 @@ const PersonalSetting = () => {
)}
</div>
</Modal>
<Modal
onCancel={() => setShowChangePasswordModal(false)}
visible={showChangePasswordModal}
size={'small'}
centered={true}
onOk={changePassword}
>
<div style={{marginTop: 20}}>
<Input
name='set_new_password'
placeholder='新密码'
value={inputs.set_new_password}
onChange={(value)=>handleInputChange('set_new_password', value)}
/>
<Input
style={{marginTop: 20}}
name='set_new_password_confirmation'
placeholder='确认新密码'
value={inputs.set_new_password_confirmation}
onChange={(value)=>handleInputChange('set_new_password_confirmation', value)}
/>
{turnstileEnabled ? (
<Turnstile
sitekey={turnstileSiteKey}
onVerify={(token) => {
setTurnstileToken(token);
}}
/>
) : (
<></>
)}
</div>
</Modal>
</div>
</Layout.Content>

View File

@@ -72,32 +72,49 @@ const UsersTable = () => {
}, {
title: '状态', dataIndex: 'status', render: (text, record, index) => {
return (<div>
{renderStatus(text)}
{record.DeletedAt !== null? <Tag color='red'>已注销</Tag> : renderStatus(text)}
</div>);
},
}, {
title: '', dataIndex: 'operate', render: (text, record, index) => (<div>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'promote', record)
}}
>
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
</Popconfirm>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'demote', record)
}}
>
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
</Popconfirm>
{
record.DeletedAt !== null ? <></>:
<>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'promote', record)
}}
>
<Button theme='light' type='warning' style={{marginRight: 1}}>提升</Button>
</Popconfirm>
<Popconfirm
title="确定?"
okType={'warning'}
onConfirm={() => {
manageUser(record.username, 'demote', record)
}}
>
<Button theme='light' type='secondary' style={{marginRight: 1}}>降级</Button>
</Popconfirm>
{record.status === 1 ?
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'disable', record)
}}>禁用</Button> :
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'enable', record);
}} disabled={record.status === 3}>启用</Button>}
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
setEditingUser(record);
setShowEditUser(true);
}}>编辑</Button>
</>
}
<Popconfirm
title="确定是否要删除此用户?"
content="此修改将不可逆"
content="硬删除,此修改将不可逆"
okType={'danger'}
position={'left'}
onConfirm={() => {
@@ -108,17 +125,6 @@ const UsersTable = () => {
>
<Button theme='light' type='danger' style={{marginRight: 1}}>删除</Button>
</Popconfirm>
{record.status === 1 ?
<Button theme='light' type='warning' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'disable', record)
}}>禁用</Button> :
<Button theme='light' type='secondary' style={{marginRight: 1}} onClick={async () => {
manageUser(record.username, 'enable', record);
}} disabled={record.status === 3}>启用</Button>}
<Button theme='light' type='tertiary' style={{marginRight: 1}} onClick={() => {
setEditingUser(record);
setShowEditUser(true);
}}>编辑</Button>
</div>),
},];

View File

@@ -1,16 +1,17 @@
export const CHANNEL_OPTIONS = [
{ key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI' },
{ key: 24, text: 'Midjourney Proxy', value: 24, color: 'light-blue', label: 'Midjourney Proxy' },
{ key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude' },
{ key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI' },
{ key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2' },
{ key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆' },
{ key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问' },
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知' },
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM' },
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue', label: '知识库FastGPT' },
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple', label: '知识库:AI Proxy' },
{key: 1, text: 'OpenAI', value: 1, color: 'green', label: 'OpenAI'},
{key: 2, text: 'Midjourney Proxy', value: 2, color: 'light-blue', label: 'Midjourney Proxy'},
{key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude'},
{key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI'},
{key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2'},
{key: 24, text: 'Google Gemini', value: 24, color: 'orange', label: 'Google Gemini'},
{key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆'},
{key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问'},
{key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知'},
{key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet', label: '智谱 ChatGLM'},
{key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑'},
{key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元'},
{key: 8, text: '自定义渠道', value: 8, color: 'pink', label: '自定义渠道'},
{key: 22, text: '知识库:FastGPT', value: 22, color: 'blue', label: '知识库:FastGPT'},
{key: 21, text: '知识库AI Proxy', value: 21, color: 'purple', label: '知识库AI Proxy'},
];

View File

@@ -86,6 +86,9 @@ const EditChannel = (props) => {
case 23:
localModels = ['hunyuan'];
break;
case 24:
localModels = ['gemini-pro'];
break;
}
setInputs((inputs) => ({...inputs, models: localModels}));
}

View File

@@ -74,12 +74,17 @@ const TopUp = () => {
const {message, data} = res.data;
// showInfo(message);
if (message === 'success') {
let params = data
let url = res.data.url
let form = document.createElement('form')
form.action = url
form.method = 'POST'
form.target = '_blank'
// 判断是否为safari浏览器
let isSafari = navigator.userAgent.indexOf("Safari") > -1 && navigator.userAgent.indexOf("Chrome") < 1;
if (!isSafari) {
form.target = '_blank'
}
for (let key in params) {
let input = document.createElement('input')
input.type = 'hidden'