diff --git a/common/constants.go b/common/constants.go index f5dbb3d5b..b8adfdbcc 100644 --- a/common/constants.go +++ b/common/constants.go @@ -208,8 +208,10 @@ const ( ChannelTypeAws = 33 ChannelTypeCohere = 34 ChannelTypeMiniMax = 35 + ChannelTypeSuno = 36 ChannelTypeDummy // this one is only for count, do not add any channel after this + ) var ChannelBaseURLs = []string{ diff --git a/constant/task.go b/constant/task.go new file mode 100644 index 000000000..1a68b8127 --- /dev/null +++ b/constant/task.go @@ -0,0 +1,18 @@ +package constant + +type TaskPlatform string + +const ( + TaskPlatformSuno TaskPlatform = "suno" + TaskPlatformMidjourney = "mj" +) + +const ( + SunoActionMusic = "MUSIC" + SunoActionLyrics = "LYRICS" +) + +var SunoModel2Action = map[string]string{ + "suno_music": SunoActionMusic, + "suno_lyrics": SunoActionLyrics, +} diff --git a/controller/relay.go b/controller/relay.go index a066e5d27..e7b819886 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -190,3 +190,94 @@ func RelayNotFound(c *gin.Context) { "error": err, }) } + +func RelayTask(c *gin.Context) { + retryTimes := common.RetryTimes + channelId := c.GetInt("channel_id") + relayMode := c.GetInt("relay_mode") + group := c.GetString("group") + originalModel := c.GetString("original_model") + c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) + taskErr := taskRelayHandler(c, relayMode) + if taskErr == nil { + retryTimes = 0 + } + for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { + channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + if err != nil { + common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + break + } + channelId = channel.Id + useChannel := c.GetStringSlice("use_channel") + useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) + c.Set("use_channel", useChannel) + common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + middleware.SetupContextForSelectedChannel(c, channel, originalModel) + + requestBody, err := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + taskErr = taskRelayHandler(c, relayMode) + } + useChannel := c.GetStringSlice("use_channel") + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + common.LogInfo(c.Request.Context(), retryLogStr) + } + if taskErr != nil { + if taskErr.StatusCode == http.StatusTooManyRequests { + taskErr.Message = "当前分组上游负载已饱和,请稍后再试" + } + c.JSON(taskErr.StatusCode, taskErr) + } +} + +func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { + var err *dto.TaskError + switch relayMode { + case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID: + err = relay.RelayTaskFetch(c, relayMode) + default: + err = relay.RelayTaskSubmit(c, relayMode) + } + return err +} + +func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool { + if taskErr == nil { + return false + } + if retryTimes <= 0 { + return false + } + if _, ok := c.Get("specific_channel_id"); ok { + return false + } + if taskErr.StatusCode == http.StatusTooManyRequests { + return true + } + if taskErr.StatusCode == 307 { + return true + } + if taskErr.StatusCode/100 == 5 { + // 超时不重试 + if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 { + return false + } + return true + } + if taskErr.StatusCode == http.StatusBadRequest { + return false + } + if taskErr.StatusCode == 408 { + // azure处理超时不重试 + return false + } + if taskErr.LocalError { + return false + } + if taskErr.StatusCode/100 == 2 { + return false + } + return true +} diff --git a/controller/task.go b/controller/task.go new file mode 100644 index 000000000..7b7d0223d --- /dev/null +++ b/controller/task.go @@ -0,0 +1,92 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "log" + "one-api/common" + "one-api/constant" + "one-api/model" + "strconv" + "time" +) + +func UpdateTaskBulk() { + //revocer + //imageModel := "midjourney" + for { + time.Sleep(time.Duration(15) * time.Second) + common.SysLog("任务进度轮询开始") + allTasks := model.GetAllUnFinishSyncTasks(500) + platformTask := make(map[constant.TaskPlatform][]*model.Task) + for _, t := range allTasks { + platformTask[t.Platform] = append(platformTask[t.Platform], t) + } + for platform, tasks := range platformTask { + UpdateTaskByPlatform(platform, tasks) + } + common.SysLog("任务进度轮询完成") + } +} + +func GetAllMidjourney(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + // 解析其他查询参数 + queryParams := model.TaskQueryParams{ + ChannelID: c.Query("channel_id"), + MjID: c.Query("mj_id"), + StartTimestamp: c.Query("start_timestamp"), + EndTimestamp: c.Query("end_timestamp"), + } + + logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams) + if logs == nil { + logs = make([]*model.Midjourney, 0) + } + if constant.MjForwardUrlEnabled { + for i, midjourney := range logs { + midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId + logs[i] = midjourney + } + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} + +func GetUserMidjourney(c *gin.Context) { + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + userId := c.GetInt("id") + log.Printf("userId = %d \n", userId) + + queryParams := model.TaskQueryParams{ + MjID: c.Query("mj_id"), + StartTimestamp: c.Query("start_timestamp"), + EndTimestamp: c.Query("end_timestamp"), + } + + logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams) + if logs == nil { + logs = make([]*model.Midjourney, 0) + } + if constant.MjForwardUrlEnabled { + for i, midjourney := range logs { + midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId + logs[i] = midjourney + } + } + c.JSON(200, gin.H{ + "success": true, + "message": "", + "data": logs, + }) +} diff --git a/dto/suno.go b/dto/suno.go new file mode 100644 index 000000000..a6bb3ebae --- /dev/null +++ b/dto/suno.go @@ -0,0 +1,129 @@ +package dto + +import ( + "encoding/json" +) + +type TaskData interface { + SunoDataResponse | []SunoDataResponse | string | any +} + +type SunoSubmitReq struct { + GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"` + Prompt string `json:"prompt,omitempty"` + Mv string `json:"mv,omitempty"` + Title string `json:"title,omitempty"` + Tags string `json:"tags,omitempty"` + ContinueAt float64 `json:"continue_at,omitempty"` + TaskID string `json:"task_id,omitempty"` + ContinueClipId string `json:"continue_clip_id,omitempty"` + MakeInstrumental bool `json:"make_instrumental"` +} + +type FetchReq struct { + IDs []string `json:"ids"` +} + +type SunoDataResponse struct { + TaskID string `json:"task_id" gorm:"type:varchar(50);index"` + Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode + Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed + FailReason string `json:"fail_reason"` + SubmitTime int64 `json:"submit_time" gorm:"index"` + StartTime int64 `json:"start_time" gorm:"index"` + FinishTime int64 `json:"finish_time" gorm:"index"` + Data json.RawMessage `json:"data" gorm:"type:json"` +} + +type SunoSong struct { + ID string `json:"id"` + VideoURL string `json:"video_url"` + AudioURL string `json:"audio_url"` + ImageURL string `json:"image_url"` + ImageLargeURL string `json:"image_large_url"` + MajorModelVersion string `json:"major_model_version"` + ModelName string `json:"model_name"` + Status string `json:"status"` + Title string `json:"title"` + Text string `json:"text"` + Metadata SunoMetadata `json:"metadata"` +} + +type SunoMetadata struct { + Tags string `json:"tags"` + Prompt string `json:"prompt"` + GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"` + AudioPromptID interface{} `json:"audio_prompt_id"` + Duration interface{} `json:"duration"` + ErrorType interface{} `json:"error_type"` + ErrorMessage interface{} `json:"error_message"` +} + +type SunoLyrics struct { + ID string `json:"id"` + Status string `json:"status"` + Title string `json:"title"` + Text string `json:"text"` +} + +const TaskSuccessCode = "success" + +type TaskResponse[T TaskData] struct { + Code string `json:"code"` + Message string `json:"message"` + Data T `json:"data"` +} + +func (t *TaskResponse[T]) IsSuccess() bool { + return t.Code == TaskSuccessCode +} + +type TaskDto struct { + TaskID string `json:"task_id"` // 第三方id,不一定有/ song id\ Task id + Action string `json:"action"` // 任务类型, song, lyrics, description-mode + Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed + FailReason string `json:"fail_reason"` + SubmitTime int64 `json:"submit_time"` + StartTime int64 `json:"start_time"` + FinishTime int64 `json:"finish_time"` + Progress string `json:"progress"` + Data json.RawMessage `json:"data"` +} + +type SunoGoAPISubmitReq struct { + CustomMode bool `json:"custom_mode"` + + Input SunoGoAPISubmitReqInput `json:"input"` + + NotifyHook string `json:"notify_hook,omitempty"` +} + +type SunoGoAPISubmitReqInput struct { + GptDescriptionPrompt string `json:"gpt_description_prompt"` + Prompt string `json:"prompt"` + Mv string `json:"mv"` + Title string `json:"title"` + Tags string `json:"tags"` + ContinueAt float64 `json:"continue_at"` + TaskID string `json:"task_id"` + ContinueClipId string `json:"continue_clip_id"` + MakeInstrumental bool `json:"make_instrumental"` +} + +type GoAPITaskResponse[T any] struct { + Code int `json:"code"` + Message string `json:"message"` + Data T `json:"data"` + ErrorMessage string `json:"error_message,omitempty"` +} + +type GoAPITaskResponseData struct { + TaskID string `json:"task_id"` +} + +type GoAPIFetchResponseData struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + Input string `json:"input"` + Clips map[string]SunoSong `json:"clips"` +} diff --git a/dto/task.go b/dto/task.go new file mode 100644 index 000000000..afc186b41 --- /dev/null +++ b/dto/task.go @@ -0,0 +1,10 @@ +package dto + +type TaskError struct { + Code string `json:"code"` + Message string `json:"message"` + Data any `json:"data"` + StatusCode int `json:"-"` + LocalError bool `json:"-"` + Error error `json:"-"` +} diff --git a/main.go b/main.go index 37c6a0a41..070fd1d10 100644 --- a/main.go +++ b/main.go @@ -20,10 +20,10 @@ import ( _ "net/http/pprof" ) -//go:embed web/dist +// /go:embed web/dist var buildFS embed.FS -//go:embed web/dist/index.html +// /go:embed web/dist/index.html var indexPage []byte func main() { diff --git a/middleware/distributor.go b/middleware/distributor.go index ae5707fb5..94079d39b 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -125,6 +125,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { modelRequest.Model = midjourneyModel } c.Set("relay_mode", relayMode) + } else if strings.Contains(c.Request.URL.Path, "/suno/") { + relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path) + if relayMode == relayconstant.RelayModeSunoFetch || + relayMode == relayconstant.RelayModeSunoFetchByID { + shouldSelectChannel = false + } else { + modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action")) + modelRequest.Model = modelName + } + c.Set("platform", constant.TaskPlatformSuno) + c.Set("relay_mode", relayMode) } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { err = common.UnmarshalBodyReusable(c, &modelRequest) } diff --git a/model/task.go b/model/task.go new file mode 100644 index 000000000..df221edfc --- /dev/null +++ b/model/task.go @@ -0,0 +1,304 @@ +package model + +import ( + "database/sql/driver" + "encoding/json" + "one-api/constant" + commonRelay "one-api/relay/common" + "time" +) + +type TaskStatus string + +const ( + TaskStatusNotStart TaskStatus = "NOT_START" + TaskStatusSubmitted = "SUBMITTED" + TaskStatusQueued = "QUEUED" + TaskStatusInProgress = "IN_PROGRESS" + TaskStatusFailure = "FAILURE" + TaskStatusSuccess = "SUCCESS" + TaskStatusUnknown = "UNKNOWN" +) + +type Task struct { + ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"` + CreatedAt int64 `json:"created_at" gorm:"index"` + UpdatedAt int64 `json:"updated_at"` + TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id,不一定有/ song id\ Task id + Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台 + UserId int `json:"user_id" gorm:"index"` + ChannelId int `json:"channel_id" gorm:"index"` + Quota int `json:"quota"` + Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode + Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态 + FailReason string `json:"fail_reason"` + SubmitTime int64 `json:"submit_time" gorm:"index"` + StartTime int64 `json:"start_time" gorm:"index"` + FinishTime int64 `json:"finish_time" gorm:"index"` + Progress string `json:"progress" gorm:"type:varchar(20);index"` + Properties Properties `json:"properties" gorm:"type:json"` + + Data json.RawMessage `json:"data" gorm:"type:json"` +} + +func (t *Task) SetData(data any) { + b, _ := json.Marshal(data) + t.Data = json.RawMessage(b) +} + +func (t *Task) GetData(v any) error { + err := json.Unmarshal(t.Data, &v) + return err +} + +type Properties struct { + Input string `json:"input"` +} + +func (m *Properties) Scan(val interface{}) error { + bytesValue, _ := val.([]byte) + return json.Unmarshal(bytesValue, m) +} + +func (m Properties) Value() (driver.Value, error) { + return json.Marshal(m) +} + +// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段 +type SyncTaskQueryParams struct { + Platform constant.TaskPlatform + ChannelID string + TaskID string + UserID string + Action string + Status string + StartTimestamp int64 + EndTimestamp int64 + UserIDs []int +} + +func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task { + t := &Task{ + UserId: relayInfo.UserId, + SubmitTime: time.Now().Unix(), + Status: TaskStatusNotStart, + Progress: "0%", + ChannelId: relayInfo.ChannelId, + Platform: platform, + } + return t +} + +func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task { + var tasks []*Task + var err error + + // 初始化查询构建器 + query := DB.Where("user_id = ?", userId) + + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.Platform != "" { + query = query.Where("platform = ?", queryParams.Platform) + } + if queryParams.StartTimestamp != 0 { + // 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析 + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + + // 获取数据 + err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error + if err != nil { + return nil + } + + return tasks +} + +func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task { + var tasks []*Task + var err error + + // 初始化查询构建器 + query := DB + + // 添加过滤条件 + if queryParams.ChannelID != "" { + query = query.Where("channel_id = ?", queryParams.ChannelID) + } + if queryParams.Platform != "" { + query = query.Where("platform = ?", queryParams.Platform) + } + if queryParams.UserID != "" { + query = query.Where("user_id = ?", queryParams.UserID) + } + if len(queryParams.UserIDs) != 0 { + query = query.Where("user_id in (?)", queryParams.UserIDs) + } + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.StartTimestamp != 0 { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + + // 获取数据 + err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error + if err != nil { + return nil + } + + return tasks +} + +func GetAllUnFinishSyncTasks(limit int) []*Task { + var tasks []*Task + var err error + // get all tasks progress is not 100% + err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error + if err != nil { + return nil + } + return tasks +} + +func GetByOnlyTaskId(taskId string) (*Task, bool, error) { + if taskId == "" { + return nil, false, nil + } + var task *Task + var err error + err = DB.Where("task_id = ?", taskId).First(&task).Error + exist, err := RecordExist(err) + if err != nil { + return nil, false, err + } + return task, exist, err +} + +func GetByTaskId(userId int, taskId string) (*Task, bool, error) { + if taskId == "" { + return nil, false, nil + } + var task *Task + var err error + err = DB.Where("user_id = ? and task_id = ?", userId, taskId). + First(&task).Error + exist, err := RecordExist(err) + if err != nil { + return nil, false, err + } + return task, exist, err +} + +func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) { + if len(taskIds) == 0 { + return nil, nil + } + var task []*Task + var err error + err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds). + Find(&task).Error + if err != nil { + return nil, err + } + return task, nil +} + +func TaskUpdateProgress(id int64, progress string) error { + return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error +} + +func (Task *Task) Insert() error { + var err error + err = DB.Create(Task).Error + return err +} + +func (Task *Task) Update() error { + var err error + err = DB.Save(Task).Error + return err +} + +func TaskBulkUpdate(TaskIds []string, params map[string]any) error { + if len(TaskIds) == 0 { + return nil + } + return DB.Model(&Task{}). + Where("task_id in (?)", TaskIds). + Updates(params).Error +} + +func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error { + if len(taskIDs) == 0 { + return nil + } + return DB.Model(&Task{}). + Where("id in (?)", taskIDs). + Updates(params).Error +} + +func TaskBulkUpdateByID(ids []int64, params map[string]any) error { + if len(ids) == 0 { + return nil + } + return DB.Model(&Task{}). + Where("id in (?)", ids). + Updates(params).Error +} + +type TaskQuotaUsage struct { + Mode string `json:"mode"` + Count float64 `json:"count"` +} + +func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) { + query := DB.Model(Task{}) + // 添加过滤条件 + if queryParams.ChannelID != "" { + query = query.Where("channel_id = ?", queryParams.ChannelID) + } + if queryParams.UserID != "" { + query = query.Where("user_id = ?", queryParams.UserID) + } + if len(queryParams.UserIDs) != 0 { + query = query.Where("user_id in (?)", queryParams.UserIDs) + } + if queryParams.TaskID != "" { + query = query.Where("task_id = ?", queryParams.TaskID) + } + if queryParams.Action != "" { + query = query.Where("action = ?", queryParams.Action) + } + if queryParams.Status != "" { + query = query.Where("status = ?", queryParams.Status) + } + if queryParams.StartTimestamp != 0 { + query = query.Where("submit_time >= ?", queryParams.StartTimestamp) + } + if queryParams.EndTimestamp != 0 { + query = query.Where("submit_time <= ?", queryParams.EndTimestamp) + } + err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error + return stat, err +} diff --git a/model/utils.go b/model/utils.go index 1c28340bf..44bfbb9e2 100644 --- a/model/utils.go +++ b/model/utils.go @@ -1,6 +1,8 @@ package model import ( + "errors" + "gorm.io/gorm" "one-api/common" "sync" "time" @@ -75,3 +77,13 @@ func batchUpdate() { } common.SysLog("batch update finished") } + +func RecordExist(err error) (bool, error) { + if err == nil { + return true, nil + } + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err +} diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index d3886d51e..6029abbf1 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -19,3 +19,21 @@ type Adaptor interface { GetModelList() []string GetChannelName() string } + +type TaskAdaptor interface { + Init(info *relaycommon.TaskRelayInfo) + + ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError + + BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) + BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error + BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) + + DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError) + + GetModelList() []string + GetChannelName() string + + // FetchTask +} diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ef82645d9..ab1131fe1 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { _ = c.Request.Body.Close() return resp, nil } + +func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.BuildRequestURL(info) + if err != nil { + return nil, err + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(requestBody), nil + } + + err = a.BuildRequestHeader(c, req, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go new file mode 100644 index 000000000..ff7261183 --- /dev/null +++ b/relay/channel/task/suno/adaptor.go @@ -0,0 +1,147 @@ +package suno + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" +) + +type TaskAdaptor struct { + ChannelType int + Action string +} + +func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { + a.ChannelType = info.ChannelType + +} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { + action := strings.ToUpper(c.Param("action")) + + var sunoRequest *dto.SunoSubmitReq + err := common.UnmarshalBodyReusable(c, &sunoRequest) + if err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + err = actionValidate(c, sunoRequest, action) + if err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) + return + } + + if sunoRequest.ContinueClipId != "" { + if sunoRequest.TaskID == "" { + taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest) + return + } + info.OriginTaskID = sunoRequest.TaskID + } + + a.Action = info.Action + c.Set("task_request", sunoRequest) + return nil +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { + baseURL := common.ChannelBaseURLs[info.ChannelType] + if info.BaseUrl != "" { + baseURL = info.BaseUrl + } + fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/submit/"+info.Action) + return fullRequestURL, nil +} + +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Header.Set("Authorization", "Bearer "+info.ApiKey) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { + sunoRequest, ok := c.Get("task_request") + if !ok { + err := common.UnmarshalBodyReusable(c, &sunoRequest) + if err != nil { + return nil, err + } + } + data, err := json.Marshal(sunoRequest) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + var sunoResponse dto.TaskResponse[string] + err = json.Unmarshal(responseBody, &sunoResponse) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + if !sunoResponse.IsSuccess() { + taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError) + return + } + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody)) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return + } + + return sunoResponse.Data, nil, nil +} + +func (a *TaskAdaptor) GetModelList() []string { + return ModelList +} + +func (a *TaskAdaptor) GetChannelName() string { + return ChannelName +} + +func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) { + switch action { + case constant.SunoActionMusic: + if sunoRequest.Mv == "" { + sunoRequest.Mv = "chirp-v3-0" + } + case constant.SunoActionLyrics: + if sunoRequest.Prompt == "" { + err = fmt.Errorf("prompt_empty") + return + } + default: + err = fmt.Errorf("invalid_action") + } + return +} diff --git a/relay/channel/task/suno/models.go b/relay/channel/task/suno/models.go new file mode 100644 index 000000000..967cf1b1d --- /dev/null +++ b/relay/channel/task/suno/models.go @@ -0,0 +1,7 @@ +package suno + +var ModelList = []string{ + "suno_music", "suno_lyrics", +} + +var ChannelName = "suno" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b40352e88..f93d36a27 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) { func (info *RelayInfo) SetIsStream(isStream bool) { info.IsStream = isStream } + +type TaskRelayInfo struct { + ChannelType int + ChannelId int + TokenId int + UserId int + Group string + StartTime time.Time + ApiType int + RelayMode int + UpstreamModelName string + RequestURLPath string + ApiKey string + BaseUrl string + + Action string + OriginTaskID string + + ConsumeQuota bool +} + +func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { + channelType := c.GetInt("channel") + channelId := c.GetInt("channel_id") + + tokenId := c.GetInt("token_id") + userId := c.GetInt("id") + group := c.GetString("group") + startTime := time.Now() + + apiType, _ := constant.ChannelType2APIType(channelType) + + info := &TaskRelayInfo{ + RelayMode: constant.Path2RelayMode(c.Request.URL.Path), + BaseUrl: c.GetString("base_url"), + RequestURLPath: c.Request.URL.String(), + ChannelType: channelType, + ChannelId: channelId, + TokenId: tokenId, + UserId: userId, + Group: group, + StartTime: startTime, + ApiType: apiType, + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + } + if info.BaseUrl == "" { + info.BaseUrl = common.ChannelBaseURLs[channelType] + } + return info +} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 2e94bc031..fa19f5053 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -1,6 +1,9 @@ package constant -import "strings" +import ( + "net/http" + "strings" +) const ( RelayModeUnknown = iota @@ -26,6 +29,9 @@ const ( RelayModeMidjourneyModal RelayModeMidjourneyShorten RelayModeSwapFace + RelayModeSunoFetch + RelayModeSunoFetchByID + RelayModeSunoSubmit ) func Path2RelayMode(path string) int { @@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int { } return relayMode } + +func Path2RelaySuno(method, path string) int { + relayMode := RelayModeUnknown + if method == http.MethodPost && strings.HasSuffix(path, "/fetch") { + relayMode = RelayModeSunoFetch + } else if method == http.MethodGet && strings.Contains(path, "/fetch/") { + relayMode = RelayModeSunoFetchByID + } else if strings.Contains(path, "/submit/") { + relayMode = RelayModeSunoSubmit + } + return relayMode +} diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index cf6305412..bfa13f4a9 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -1,6 +1,7 @@ package relay import ( + commonconstant "one-api/constant" "one-api/relay/channel" "one-api/relay/channel/ali" "one-api/relay/channel/aws" @@ -12,6 +13,7 @@ import ( "one-api/relay/channel/openai" "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" + "one-api/relay/channel/task/suno" "one-api/relay/channel/tencent" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" @@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor { } return nil } + +func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor { + switch platform { + //case constant.APITypeAIProxyLibrary: + // return &aiproxy.Adaptor{} + case commonconstant.TaskPlatformSuno: + return &suno.TaskAdaptor{} + } + return nil +} diff --git a/relay/relay_task.go b/relay/relay_task.go new file mode 100644 index 000000000..47d8a5c2f --- /dev/null +++ b/relay/relay_task.go @@ -0,0 +1,242 @@ +package relay + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/constant" + "one-api/dto" + "one-api/model" + relaycommon "one-api/relay/common" + relayconstant "one-api/relay/constant" + "one-api/service" +) + +/* +Task 任务通过平台、Action 区分任务 +*/ +func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { + platform := constant.TaskPlatform(c.GetString("platform")) + relayInfo := relaycommon.GenTaskRelayInfo(c) + + adaptor := GetTaskAdaptor(platform) + if adaptor == nil { + return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + // get & validate taskRequest 获取并验证文本请求 + taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo) + if taskErr != nil { + return + } + + modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action) + modelPrice, success := common.GetModelPrice(modelName, true) + if !success { + defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName] + if !ok { + modelPrice = 0.1 + } else { + modelPrice = defaultPrice + } + } + + // 预扣 + groupRatio := common.GetGroupRatio(relayInfo.Group) + ratio := modelPrice * groupRatio + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + return + } + quota := int(ratio * common.QuotaPerUnit) + if userQuota-quota < 0 { + taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden) + return + } + + if relayInfo.OriginTaskID != "" { + originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + if originTask.ChannelId != relayInfo.ChannelId { + channel, err := model.GetChannelById(originTask.ChannelId, true) + if err != nil { + taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) + return + } + if channel.Status != common.ChannelStatusEnabled { + return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest) + } + c.Set("base_url", channel.GetBaseURL()) + c.Set("channel_id", originTask.ChannelId) + c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) + + relayInfo.BaseUrl = channel.GetBaseURL() + relayInfo.ChannelId = originTask.ChannelId + } + } + + // build body + requestBody, err := adaptor.BuildRequestBody(c, relayInfo) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) + return + } + // do request + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + return + } + // handle response + if resp != nil && resp.StatusCode != http.StatusOK { + responseBody, _ := io.ReadAll(resp.Body) + taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode) + return + } + + defer func(ctx context.Context) { + // release quota + if relayInfo.ConsumeQuota && taskErr == nil { + err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(relayInfo.UserId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action) + other := make(map[string]interface{}) + other["model_price"] = modelPrice + other["group_ratio"] = groupRatio + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other) + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + } + }(c.Request.Context()) + + taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) + if taskErr != nil { + return + } + relayInfo.ConsumeQuota = true + // insert task + task := model.InitTask(constant.TaskPlatformSuno, relayInfo) + task.TaskID = taskID + task.Quota = quota + task.Data = taskData + err = task.Insert() + if err != nil { + taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) + return + } + return nil +} + +var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ + relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, + relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, +} + +func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { + respBuilder, ok := fetchRespBuilders[relayMode] + if !ok { + taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest) + } + + respBody, taskErr := respBuilder(c) + if taskErr != nil { + return taskErr + } + + c.Writer.Header().Set("Content-Type", "application/json") + _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) + return + } + return +} + +func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + userId := c.GetInt("id") + var condition = struct { + IDs []any `json:"ids"` + Action string `json:"action"` + }{} + err := c.BindJSON(&condition) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest) + return + } + var tasks []any + if len(condition.IDs) > 0 { + taskModels, err := model.GetByTaskIds(userId, condition.IDs) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError) + return + } + for _, task := range taskModels { + tasks = append(tasks, TaskModel2Dto(task)) + } + } else { + tasks = make([]any, 0) + } + respBody, err = json.Marshal(dto.TaskResponse[[]any]{ + Code: "success", + Data: tasks, + }) + return +} + +func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { + taskId := c.Param("id") + userId := c.GetInt("id") + + originTask, exist, err := model.GetByTaskId(userId, taskId) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) + return + } + if !exist { + taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) + return + } + + respBody, err = json.Marshal(dto.TaskResponse[any]{ + Code: "success", + Data: TaskModel2Dto(originTask), + }) + return +} + +func TaskModel2Dto(task *model.Task) *dto.TaskDto { + return &dto.TaskDto{ + TaskID: task.TaskID, + Action: task.Action, + Status: string(task.Status), + FailReason: task.FailReason, + SubmitTime: task.SubmitTime, + StartTime: task.StartTime, + FinishTime: task.FinishTime, + Progress: task.Progress, + Data: task.Data, + } +} diff --git a/router/relay-router.go b/router/relay-router.go index 2d8e7b38d..3ad9e3774 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -50,6 +50,15 @@ func SetRelayRouter(router *gin.Engine) { relayMjModeRouter := router.Group("/:mode/mj") registerMjRouterGroup(relayMjModeRouter) //relayMjRouter.Use() + + relaySunoRouter := router.Group("/suno") + relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute()) + { + relaySunoRouter.POST("/submit/:action", controller.RelayTask) + relaySunoRouter.POST("/fetch", controller.RelayTask) + relaySunoRouter.GET("/fetch/:id", controller.RelayTask) + } + } func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { diff --git a/service/error.go b/service/error.go index 4b00f3774..0f6d472fb 100644 --- a/service/error.go +++ b/service/error.go @@ -105,3 +105,29 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping openaiErr.StatusCode = intCode } } + +func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError { + openaiErr := TaskErrorWrapper(err, code, statusCode) + openaiErr.LocalError = true + return openaiErr +} + +func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { + text := err.Error() + + // 定义一个正则表达式匹配URL + if strings.Contains(text, "Post") || strings.Contains(text, "dial") { + common.SysLog(fmt.Sprintf("error: %s", text)) + text = "请求上游地址失败" + } + //避免暴露内部错误 + + taskError := &dto.TaskError{ + Code: code, + Message: text, + StatusCode: statusCode, + Error: err, + } + + return taskError +} diff --git a/service/task.go b/service/task.go new file mode 100644 index 000000000..c2501fe28 --- /dev/null +++ b/service/task.go @@ -0,0 +1,10 @@ +package service + +import ( + "one-api/constant" + "strings" +) + +func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string { + return strings.ToLower(string(platform)) + "_" + strings.ToLower(action) +}