From 606aa8a4a735586346cb9cab9e34343c60f5cc5d Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Thu, 13 Jun 2024 00:32:14 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=B0=83=E8=AF=95=20suno?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 4 +- controller/channel-test.go | 3 + controller/misc.go | 1 + controller/task.go | 268 +++++++++++++++-- main.go | 3 + middleware/distributor.go | 2 +- model/main.go | 4 + model/option.go | 3 + relay/channel/task/suno/adaptor.go | 9 +- router/api-router.go | 6 + web/src/App.js | 11 + web/src/components/SiderBar.js | 14 +- web/src/components/TaskLogsTable.js | 400 +++++++++++++++++++++++++ web/src/constants/channel.constants.js | 7 + web/src/helpers/data.js | 1 + web/src/pages/Channel/EditChannel.js | 33 +- web/src/pages/Task/index.js | 10 + 17 files changed, 737 insertions(+), 42 deletions(-) create mode 100644 web/src/components/TaskLogsTable.js create mode 100644 web/src/pages/Task/index.js diff --git a/common/constants.go b/common/constants.go index b8adfdbcc..1bca1418e 100644 --- a/common/constants.go +++ b/common/constants.go @@ -21,6 +21,7 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true var DrawingEnabled = true +var TaskEnabled = true var DataExportEnabled = true var DataExportInterval = 5 // unit: minute var DataExportDefaultTime = "hour" // unit: minute @@ -208,7 +209,7 @@ const ( ChannelTypeAws = 33 ChannelTypeCohere = 34 ChannelTypeMiniMax = 35 - ChannelTypeSuno = 36 + ChannelTypeSunoAPI = 36 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -251,4 +252,5 @@ var ChannelBaseURLs = []string{ "", //33 "https://api.cohere.ai", //34 "https://api.minimax.chat", //35 + "", //36 } diff --git a/controller/channel-test.go b/controller/channel-test.go index db03e7529..0b8d442f5 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -27,6 +27,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil } + if channel.Type == common.ChannelTypeSunoAPI { + return errors.New("suno channel test is not supported"), nil + } w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ diff --git a/controller/misc.go b/controller/misc.go index b8203f3a3..5e12854b3 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -57,6 +57,7 @@ func GetStatus(c *gin.Context) { "display_in_currency": common.DisplayInCurrencyEnabled, "enable_batch_update": common.BatchUpdateEnabled, "enable_drawing": common.DrawingEnabled, + "enable_task": common.TaskEnabled, "enable_data_export": common.DataExportEnabled, "data_export_default_time": common.DataExportDefaultTime, "default_collapse_sidebar": common.DefaultCollapseSidebar, diff --git a/controller/task.go b/controller/task.go index 7b7d0223d..02a7933db 100644 --- a/controller/task.go +++ b/controller/task.go @@ -1,11 +1,21 @@ package controller import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" "github.com/gin-gonic/gin" - "log" + "github.com/samber/lo" + "io" + "net/http" "one-api/common" "one-api/constant" + "one-api/dto" "one-api/model" + "one-api/service" + "sort" "strconv" "time" ) @@ -16,42 +26,238 @@ func UpdateTaskBulk() { for { time.Sleep(time.Duration(15) * time.Second) common.SysLog("任务进度轮询开始") + ctx := context.TODO() allTasks := model.GetAllUnFinishSyncTasks(500) platformTask := make(map[constant.TaskPlatform][]*model.Task) for _, t := range allTasks { platformTask[t.Platform] = append(platformTask[t.Platform], t) } for platform, tasks := range platformTask { - UpdateTaskByPlatform(platform, tasks) + if len(tasks) == 0 { + continue + } + taskChannelM := make(map[int][]string) + taskM := make(map[string]*model.Task) + nullTaskIds := make([]int64, 0) + for _, task := range tasks { + if task.TaskID == "" { + // 统计失败的未完成任务 + nullTaskIds = append(nullTaskIds, task.ID) + continue + } + taskM[task.TaskID] = task + taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID) + } + if len(nullTaskIds) > 0 { + err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) + } else { + common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) + } + } + if len(taskChannelM) == 0 { + continue + } + + UpdateTaskByPlatform(platform, taskChannelM, taskM) } common.SysLog("任务进度轮询完成") } } -func GetAllMidjourney(c *gin.Context) { +func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { + switch platform { + case constant.TaskPlatformMidjourney: + //_ = UpdateMidjourneyTaskAll(context.Background(), tasks) + case constant.TaskPlatformSuno: + _ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM) + default: + common.SysLog("未知平台") + } +} + +func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { + for channelId, taskIds := range taskChannelM { + err := updateSunoTaskAll(ctx, channelId, taskIds, taskM) + if err != nil { + common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error())) + } + } + return nil +} + +func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { + common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) + if len(taskIds) == 0 { + return nil + } + channel, err := model.CacheGetChannel(channelId) + if err != nil { + common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) + err = model.TaskBulkUpdate(taskIds, map[string]any{ + "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), + "status": "FAILURE", + "progress": "100%", + }) + if err != nil { + common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err)) + } + return err + } + requestUrl := fmt.Sprintf("%s/fetch", *channel.BaseURL) + + body, _ := json.Marshal(map[string]any{ + "ids": taskIds, + }) + req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body)) + if err != nil { + common.SysError(fmt.Sprintf("Get Task error: %v", err)) + return err + } + defer req.Body.Close() + // 设置超时时间 + timeout := time.Second * 15 + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + // 使用带有超时的 context 创建新的请求 + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+channel.Key) + resp, err := service.GetHttpClient().Do(req) + if err != nil { + common.SysError(fmt.Sprintf("Get Task Do req error: %v", err)) + return err + } + if resp.StatusCode != http.StatusOK { + common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) + } + defer resp.Body.Close() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError(fmt.Sprintf("Get Task parse body error: %v", err)) + return err + } + var responseItems dto.TaskResponse[[]dto.SunoDataResponse] + err = json.Unmarshal(responseBody, &responseItems) + if err != nil { + common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, req: %s, body: %s", err, string(body), string(responseBody))) + return err + } + if !responseItems.IsSuccess() { + common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody))) + return err + } + + for _, responseItem := range responseItems.Data { + task := taskM[responseItem.TaskID] + if !checkTaskNeedUpdate(task, responseItem) { + continue + } + + task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) + task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) + task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) + task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) + task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) + if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { + common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) + task.Progress = "100%" + err = model.CacheUpdateUserQuota(task.UserId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } else { + quota := task.Quota + if quota != 0 { + err = model.IncreaseUserQuota(task.UserId, quota) + if err != nil { + common.LogError(ctx, "fail to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } + } + if responseItem.Status == model.TaskStatusSuccess { + task.Progress = "100%" + } + task.Data = responseItem.Data + + err = task.Update() + if err != nil { + common.SysError("UpdateMidjourneyTask task error: " + err.Error()) + } + } + return nil +} + +func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { + + if oldTask.SubmitTime != newTask.SubmitTime { + return true + } + if oldTask.StartTime != newTask.StartTime { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + if string(oldTask.Status) != newTask.Status { + return true + } + if oldTask.FailReason != newTask.FailReason { + return true + } + if oldTask.FinishTime != newTask.FinishTime { + return true + } + + if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { + return true + } + + oldData, _ := json.Marshal(oldTask.Data) + newData, _ := json.Marshal(newTask.Data) + + sort.Slice(oldData, func(i, j int) bool { + return oldData[i] < oldData[j] + }) + sort.Slice(newData, func(i, j int) bool { + return newData[i] < newData[j] + }) + + if string(oldData) != string(newData) { + return true + } + return false +} + +func GetAllTask(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) if p < 0 { p = 0 } - + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) // 解析其他查询参数 - queryParams := model.TaskQueryParams{ - ChannelID: c.Query("channel_id"), - MjID: c.Query("mj_id"), - StartTimestamp: c.Query("start_timestamp"), - EndTimestamp: c.Query("end_timestamp"), + queryParams := model.SyncTaskQueryParams{ + Platform: constant.TaskPlatform(c.Query("platform")), + TaskID: c.Query("task_id"), + Status: c.Query("status"), + Action: c.Query("action"), + StartTimestamp: startTimestamp, + EndTimestamp: endTimestamp, } - logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) + logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) if logs == nil { - logs = make([]*model.Midjourney, 0) - } - if constant.MjForwardUrlEnabled { - for i, midjourney := range logs { - midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId - logs[i] = midjourney - } + logs = make([]*model.Task, 0) } + c.JSON(200, gin.H{ "success": true, "message": "", @@ -59,31 +265,31 @@ func GetAllMidjourney(c *gin.Context) { }) } -func GetUserMidjourney(c *gin.Context) { +func GetUserTask(c *gin.Context) { p, _ := strconv.Atoi(c.Query("p")) if p < 0 { p = 0 } userId := c.GetInt("id") - log.Printf("userId = %d \n", userId) - queryParams := model.TaskQueryParams{ - MjID: c.Query("mj_id"), - StartTimestamp: c.Query("start_timestamp"), - EndTimestamp: c.Query("end_timestamp"), + startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) + endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) + + queryParams := model.SyncTaskQueryParams{ + Platform: constant.TaskPlatform(c.Query("platform")), + TaskID: c.Query("task_id"), + Status: c.Query("status"), + Action: c.Query("action"), + StartTimestamp: startTimestamp, + EndTimestamp: endTimestamp, } - logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) + logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) if logs == nil { - logs = make([]*model.Midjourney, 0) - } - if constant.MjForwardUrlEnabled { - for i, midjourney := range logs { - midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId - logs[i] = midjourney - } + logs = make([]*model.Task, 0) } + c.JSON(200, gin.H{ "success": true, "message": "", diff --git a/main.go b/main.go index 070fd1d10..006c11803 100644 --- a/main.go +++ b/main.go @@ -92,6 +92,9 @@ func main() { common.SafeGoroutine(func() { controller.UpdateMidjourneyTaskBulk() }) + common.SafeGoroutine(func() { + controller.UpdateTaskBulk() + }) if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { common.BatchUpdateEnabled = true common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") diff --git a/middleware/distributor.go b/middleware/distributor.go index 94079d39b..4862a4845 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -134,7 +134,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action")) modelRequest.Model = modelName } - c.Set("platform", constant.TaskPlatformSuno) + c.Set("platform", string(constant.TaskPlatformSuno)) c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err = common.UnmarshalBodyReusable(c, &modelRequest) diff --git a/model/main.go b/model/main.go index b6ad2cb31..710ea0592 100644 --- a/model/main.go +++ b/model/main.go @@ -140,6 +140,10 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&Task{}) + if err != nil { + return err + } common.SysLog("database migrated") err = createRootAccountIfNeed() return err diff --git a/model/option.go b/model/option.go index 6aa59cb57..45aa52408 100644 --- a/model/option.go +++ b/model/option.go @@ -41,6 +41,7 @@ func InitOptionMap() { common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled) + common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled) common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled) common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) @@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) { common.DisplayTokenStatEnabled = boolValue case "DrawingEnabled": common.DrawingEnabled = boolValue + case "TaskEnabled": + common.TaskEnabled = boolValue case "DataExportEnabled": common.DataExportEnabled = boolValue case "DefaultCollapseSidebar": diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index ff7261183..2c6d64d58 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -18,12 +18,10 @@ import ( type TaskAdaptor struct { ChannelType int - Action string } func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { a.ChannelType = info.ChannelType - } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { @@ -49,16 +47,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom info.OriginTaskID = sunoRequest.TaskID } - a.Action = info.Action + info.Action = action c.Set("task_request", sunoRequest) return nil } func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { - baseURL := common.ChannelBaseURLs[info.ChannelType] - if info.BaseUrl != "" { - baseURL = info.BaseUrl - } + baseURL := info.BaseUrl fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/submit/"+info.Action) return fullRequestURL, nil } diff --git a/router/api-router.go b/router/api-router.go index 7657a98a9..68079396a 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -140,5 +140,11 @@ func SetApiRouter(router *gin.Engine) { mjRoute := apiRouter.Group("/mj") mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney) mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney) + + taskRoute := apiRouter.Group("/task") + { + taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask) + taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask) + } } } diff --git a/web/src/App.js b/web/src/App.js index 1b63def8b..0db9a2289 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -23,6 +23,7 @@ import Chat from './pages/Chat'; import { Layout } from '@douyinfe/semi-ui'; import Midjourney from './pages/Midjourney'; import Pricing from './pages/Pricing/index.js'; +import Task from "./pages/Task/index.js"; // import Detail from './pages/Detail'; const Home = lazy(() => import('./pages/Home')); @@ -220,6 +221,16 @@ function App() { } /> + + }> + + + + } + /> { chat: '/chat', detail: '/detail', pricing: '/pricing', + task: '/task', }; const headerButtons = useMemo( @@ -142,6 +143,16 @@ const SiderBar = () => { ? 'semi-navigation-item-normal' : 'tableHiddle', }, + { + text: '异步任务', + itemKey: 'task', + to: '/task', + icon: , + className: + localStorage.getItem('enable_task') === 'true' + ? 'semi-navigation-item-normal' + : 'tableHiddle', + }, { text: '设置', itemKey: 'setting', @@ -158,6 +169,7 @@ const SiderBar = () => { [ localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), + localStorage.getItem('enable_task'), localStorage.getItem('chat_link'), isAdmin(), ], diff --git a/web/src/components/TaskLogsTable.js b/web/src/components/TaskLogsTable.js new file mode 100644 index 000000000..52bf39bbe --- /dev/null +++ b/web/src/components/TaskLogsTable.js @@ -0,0 +1,400 @@ +import React, { useEffect, useState } from 'react'; +import { Label } from 'semantic-ui-react'; +import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers'; + +import { + Table, + Tag, + Form, + Button, + Layout, + Modal, + Typography, Progress, Card +} from '@douyinfe/semi-ui'; +import { ITEMS_PER_PAGE } from '../constants'; + +const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo', + 'light-blue', 'lime', 'orange', 'pink', + 'purple', 'red', 'teal', 'violet', 'yellow' +] + + +const renderTimestamp = (timestampInSeconds) => { + const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒 + + const year = date.getFullYear(); // 获取年份 + const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数 + const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数 + const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数 + const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数 + const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数 + + return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出 +}; + +function renderDuration(submit_time, finishTime) { + // 确保startTime和finishTime都是有效的时间戳 + if (!submit_time || !finishTime) return 'N/A'; + + // 将时间戳转换为Date对象 + const start = new Date(submit_time); + const finish = new Date(finishTime); + + // 计算时间差(毫秒) + const durationMs = finish - start; + + // 将时间差转换为秒,并保留一位小数 + const durationSec = (durationMs / 1000).toFixed(1); + + // 设置颜色:大于60秒则为红色,小于等于60秒则为绿色 + const color = durationSec > 60 ? 'red' : 'green'; + + // 返回带有样式的颜色标签 + return ( + + {durationSec} 秒 + + ); +} + +const LogsTable = () => { + const [isModalOpen, setIsModalOpen] = useState(false); + const [modalContent, setModalContent] = useState(''); + const isAdminUser = isAdmin(); + const columns = [ + { + title: "提交时间", + dataIndex: 'submit_time', + render: (text, record, index) => { + return ( +
+ {text ? renderTimestamp(text) : "-"} +
+ ); + }, + }, + { + title: "结束时间", + dataIndex: 'finish_time', + render: (text, record, index) => { + return ( +
+ {text ? renderTimestamp(text) : "-"} +
+ ); + }, + }, + { + title: '进度', + dataIndex: 'progress', + width: 50, + render: (text, record, index) => { + return ( +
+ { + // 转换例如100%为数字100,如果text未定义,返回0 + isNaN(text.replace('%', '')) ? text : + } +
+ ); + }, + }, + { + title: '花费时间', + dataIndex: 'finish_time', // 以finish_time作为dataIndex + key: 'finish_time', + render: (finish, record) => { + // 假设record.start_time是存在的,并且finish是完成时间的时间戳 + return <> + { + finish ? renderDuration(record.submit_time, finish) : "-" + } + + }, + }, + { + title: "渠道", + dataIndex: 'channel_id', + className: isAdminUser ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + return ( +
+ { + copyText(text); // 假设copyText是用于文本复制的函数 + }} + > + {' '} + {text}{' '} + +
+ ); + }, + }, + { + title: "平台", + dataIndex: 'platform', + render: (text, record, index) => { + return ( +
+ {renderPlatform(text)} +
+ ); + }, + }, + { + title: '类型', + dataIndex: 'action', + render: (text, record, index) => { + return ( +
+ {renderType(text)} +
+ ); + }, + }, + { + title: '任务ID(点击查看详情)', + dataIndex: 'task_id', + render: (text, record, index) => { + return ( { + setModalContent(JSON.stringify(record, null, 2)); + setIsModalOpen(true); + }} + > +
+ {text} +
+
); + }, + }, + { + title: '任务状态', + dataIndex: 'status', + render: (text, record, index) => { + return ( +
+ {renderStatus(text)} +
+ ); + }, + }, + + { + title: '失败原因', + dataIndex: 'fail_reason', + render: (text, record, index) => { + // 如果text未定义,返回替代文本,例如空字符串''或其他 + if (!text) { + return '无'; + } + + return ( + { + setModalContent(text); + setIsModalOpen(true); + }} + > + {text} + + ); + } + } + ]; + + const [logs, setLogs] = useState([]); + const [loading, setLoading] = useState(true); + const [activePage, setActivePage] = useState(1); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); + const [logType] = useState(0); + + let now = new Date(); + // 初始化start_timestamp为前一天 + let zeroNow = new Date(now.getFullYear(), now.getMonth(), now.getDate()); + const [inputs, setInputs] = useState({ + channel_id: '', + task_id: '', + start_timestamp: timestamp2string(zeroNow.getTime() /1000), + end_timestamp: '', + }); + const { channel_id, task_id, start_timestamp, end_timestamp } = inputs; + + const handleInputChange = (value, name) => { + setInputs((inputs) => ({ ...inputs, [name]: value })); + }; + + + const setLogsFormat = (logs) => { + for (let i = 0; i < logs.length; i++) { + logs[i].timestamp2string = timestamp2string(logs[i].created_at); + logs[i].key = '' + logs[i].id; + } + // data.key = '' + data.id + setLogs(logs); + setLogCount(logs.length + ITEMS_PER_PAGE); + // console.log(logCount); + } + + const loadLogs = async (startIdx) => { + setLoading(true); + + let url = ''; + let localStartTimestamp = parseInt(Date.parse(start_timestamp) / 1000); + let localEndTimestamp = parseInt(Date.parse(end_timestamp) / 1000 ); + if (isAdminUser) { + url = `/api/task/?p=${startIdx}&channel_id=${channel_id}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } else { + url = `/api/task/self?p=${startIdx}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + } + const res = await API.get(url); + let { success, message, data } = res.data; + if (success) { + if (startIdx === 0) { + setLogsFormat(data); + } else { + let newLogs = [...logs]; + newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); + setLogsFormat(newLogs); + } + } else { + showError(message); + } + setLoading(false); + }; + + const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE); + + const handlePageChange = page => { + setActivePage(page); + if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) { + // In this case we have to load more data and then append them. + loadLogs(page - 1).then(r => { + }); + } + }; + + const refresh = async () => { + // setLoading(true); + setActivePage(1); + await loadLogs(0); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: "无法复制到剪贴板,请手动复制", content: text }); + } + } + + useEffect(() => { + refresh().then(); + }, [logType]); + + const renderType = (type) => { + switch (type) { + case 'MUSIC': + return ; + case 'LYRICS': + return ; + + default: + return ; + } + } + + const renderPlatform = (type) => { + switch (type) { + case "suno": + return ; + default: + return ; + } + } + + const renderStatus = (type) => { + switch (type) { + case 'SUCCESS': + return ; + case 'NOT_START': + return ; + case 'SUBMITTED': + return ; + case 'IN_PROGRESS': + return ; + case 'FAILURE': + return ; + case 'QUEUED': + return ; + case 'UNKNOWN': + return ; + case '': + return ; + default: + return ; + } + } + + return ( + <> + + +
+ <> + {isAdminUser && handleInputChange(value, 'channel_id')} /> + } + handleInputChange(value, 'task_id')} /> + + handleInputChange(value, 'start_timestamp')} /> + handleInputChange(value, 'end_timestamp')} /> + + + + + + + setIsModalOpen(false)} + onCancel={() => setIsModalOpen(false)} + closable={null} + bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式 + width={800} // 设置模态框宽度 + > +

{modalContent}

+
+ + + ); +}; + +export default LogsTable; diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 9b10f0fd6..e67dbc61f 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -14,6 +14,13 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: 'Midjourney Proxy Plus', }, + { + key: 36, + text: 'Suno API', + value: 36, + color: 'purple', + label: 'Suno API', + }, { key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' }, { key: 14, diff --git a/web/src/helpers/data.js b/web/src/helpers/data.js index 750b670f5..93380b41b 100644 --- a/web/src/helpers/data.js +++ b/web/src/helpers/data.js @@ -6,6 +6,7 @@ export function setStatusData(data) { localStorage.setItem('quota_per_unit', data.quota_per_unit); localStorage.setItem('display_in_currency', data.display_in_currency); localStorage.setItem('enable_drawing', data.enable_drawing); + localStorage.setItem('enable_task', data.enable_task); localStorage.setItem('enable_data_export', data.enable_data_export); localStorage.setItem( 'data_export_default_time', diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 821a0563b..540e3ecf2 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -126,6 +126,12 @@ const EditChannel = (props) => { 'mj_uploads', ]; break; + case 36: + localModels = [ + 'suno_music', + 'suno_lyrics', + ]; + break; default: localModels = getChannelModels(value); break; @@ -513,6 +519,31 @@ const EditChannel = (props) => { /> )} + {inputs.type === 36 && ( + <> +
+ + Suno 非官方 API,https://github.com/Suno-API/Suno-API + + } + > +
+ { + handleInputChange('base_url', value); + }} + value={inputs.base_url} + autoComplete='new-password' + /> + + )}
名称:
@@ -758,7 +789,7 @@ const EditChannel = (props) => { )} - {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( + {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && ( <>
代理: diff --git a/web/src/pages/Task/index.js b/web/src/pages/Task/index.js new file mode 100644 index 000000000..aec3702f5 --- /dev/null +++ b/web/src/pages/Task/index.js @@ -0,0 +1,10 @@ +import React from 'react'; +import TaskLogsTable from "../../components/TaskLogsTable.js"; + +const Task = () => ( + <> + + +); + +export default Task;