feat: suno api 支持

This commit is contained in:
Xiangyuan Liu
2024-06-12 20:37:42 +08:00
parent c1040afed9
commit 1e8abc7027
21 changed files with 1235 additions and 3 deletions

View File

@@ -208,8 +208,10 @@ const (
ChannelTypeAws = 33
ChannelTypeCohere = 34
ChannelTypeMiniMax = 35
ChannelTypeSuno = 36
ChannelTypeDummy // this one is only for count, do not add any channel after this
)
var ChannelBaseURLs = []string{

18
constant/task.go Normal file
View File

@@ -0,0 +1,18 @@
package constant
type TaskPlatform string
const (
TaskPlatformSuno TaskPlatform = "suno"
TaskPlatformMidjourney = "mj"
)
const (
SunoActionMusic = "MUSIC"
SunoActionLyrics = "LYRICS"
)
var SunoModel2Action = map[string]string{
"suno_music": SunoActionMusic,
"suno_lyrics": SunoActionLyrics,
}

View File

@@ -190,3 +190,94 @@ func RelayNotFound(c *gin.Context) {
"error": err,
})
}
func RelayTask(c *gin.Context) {
retryTimes := common.RetryTimes
channelId := c.GetInt("channel_id")
relayMode := c.GetInt("relay_mode")
group := c.GetString("group")
originalModel := c.GetString("original_model")
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
taskErr := taskRelayHandler(c, relayMode)
if taskErr == nil {
retryTimes = 0
}
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
taskErr = taskRelayHandler(c, relayMode)
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {
taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
}
c.JSON(taskErr.StatusCode, taskErr)
}
}
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
var err *dto.TaskError
switch relayMode {
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
err = relay.RelayTaskFetch(c, relayMode)
default:
err = relay.RelayTaskSubmit(c, relayMode)
}
return err
}
func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
if taskErr == nil {
return false
}
if retryTimes <= 0 {
return false
}
if _, ok := c.Get("specific_channel_id"); ok {
return false
}
if taskErr.StatusCode == http.StatusTooManyRequests {
return true
}
if taskErr.StatusCode == 307 {
return true
}
if taskErr.StatusCode/100 == 5 {
// 超时不重试
if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
return false
}
return true
}
if taskErr.StatusCode == http.StatusBadRequest {
return false
}
if taskErr.StatusCode == 408 {
// azure处理超时不重试
return false
}
if taskErr.LocalError {
return false
}
if taskErr.StatusCode/100 == 2 {
return false
}
return true
}

92
controller/task.go Normal file
View File

@@ -0,0 +1,92 @@
package controller
import (
"github.com/gin-gonic/gin"
"log"
"one-api/common"
"one-api/constant"
"one-api/model"
"strconv"
"time"
)
func UpdateTaskBulk() {
//revocer
//imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
common.SysLog("任务进度轮询开始")
allTasks := model.GetAllUnFinishSyncTasks(500)
platformTask := make(map[constant.TaskPlatform][]*model.Task)
for _, t := range allTasks {
platformTask[t.Platform] = append(platformTask[t.Platform], t)
}
for platform, tasks := range platformTask {
UpdateTaskByPlatform(platform, tasks)
}
common.SysLog("任务进度轮询完成")
}
}
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
// 解析其他查询参数
queryParams := model.TaskQueryParams{
ChannelID: c.Query("channel_id"),
MjID: c.Query("mj_id"),
StartTimestamp: c.Query("start_timestamp"),
EndTimestamp: c.Query("end_timestamp"),
}
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
userId := c.GetInt("id")
log.Printf("userId = %d \n", userId)
queryParams := model.TaskQueryParams{
MjID: c.Query("mj_id"),
StartTimestamp: c.Query("start_timestamp"),
EndTimestamp: c.Query("end_timestamp"),
}
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}

129
dto/suno.go Normal file
View File

@@ -0,0 +1,129 @@
package dto
import (
"encoding/json"
)
type TaskData interface {
SunoDataResponse | []SunoDataResponse | string | any
}
type SunoSubmitReq struct {
GptDescriptionPrompt string `json:"gpt_description_prompt,omitempty"`
Prompt string `json:"prompt,omitempty"`
Mv string `json:"mv,omitempty"`
Title string `json:"title,omitempty"`
Tags string `json:"tags,omitempty"`
ContinueAt float64 `json:"continue_at,omitempty"`
TaskID string `json:"task_id,omitempty"`
ContinueClipId string `json:"continue_clip_id,omitempty"`
MakeInstrumental bool `json:"make_instrumental"`
}
type FetchReq struct {
IDs []string `json:"ids"`
}
type SunoDataResponse struct {
TaskID string `json:"task_id" gorm:"type:varchar(50);index"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
Status string `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
Data json.RawMessage `json:"data" gorm:"type:json"`
}
type SunoSong struct {
ID string `json:"id"`
VideoURL string `json:"video_url"`
AudioURL string `json:"audio_url"`
ImageURL string `json:"image_url"`
ImageLargeURL string `json:"image_large_url"`
MajorModelVersion string `json:"major_model_version"`
ModelName string `json:"model_name"`
Status string `json:"status"`
Title string `json:"title"`
Text string `json:"text"`
Metadata SunoMetadata `json:"metadata"`
}
type SunoMetadata struct {
Tags string `json:"tags"`
Prompt string `json:"prompt"`
GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
AudioPromptID interface{} `json:"audio_prompt_id"`
Duration interface{} `json:"duration"`
ErrorType interface{} `json:"error_type"`
ErrorMessage interface{} `json:"error_message"`
}
type SunoLyrics struct {
ID string `json:"id"`
Status string `json:"status"`
Title string `json:"title"`
Text string `json:"text"`
}
const TaskSuccessCode = "success"
type TaskResponse[T TaskData] struct {
Code string `json:"code"`
Message string `json:"message"`
Data T `json:"data"`
}
func (t *TaskResponse[T]) IsSuccess() bool {
return t.Code == TaskSuccessCode
}
type TaskDto struct {
TaskID string `json:"task_id"` // 第三方id不一定有/ song id\ Task id
Action string `json:"action"` // 任务类型, song, lyrics, description-mode
Status string `json:"status"` // 任务状态, submitted, queueing, processing, success, failed
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time"`
StartTime int64 `json:"start_time"`
FinishTime int64 `json:"finish_time"`
Progress string `json:"progress"`
Data json.RawMessage `json:"data"`
}
type SunoGoAPISubmitReq struct {
CustomMode bool `json:"custom_mode"`
Input SunoGoAPISubmitReqInput `json:"input"`
NotifyHook string `json:"notify_hook,omitempty"`
}
type SunoGoAPISubmitReqInput struct {
GptDescriptionPrompt string `json:"gpt_description_prompt"`
Prompt string `json:"prompt"`
Mv string `json:"mv"`
Title string `json:"title"`
Tags string `json:"tags"`
ContinueAt float64 `json:"continue_at"`
TaskID string `json:"task_id"`
ContinueClipId string `json:"continue_clip_id"`
MakeInstrumental bool `json:"make_instrumental"`
}
type GoAPITaskResponse[T any] struct {
Code int `json:"code"`
Message string `json:"message"`
Data T `json:"data"`
ErrorMessage string `json:"error_message,omitempty"`
}
type GoAPITaskResponseData struct {
TaskID string `json:"task_id"`
}
type GoAPIFetchResponseData struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
Input string `json:"input"`
Clips map[string]SunoSong `json:"clips"`
}

10
dto/task.go Normal file
View File

@@ -0,0 +1,10 @@
package dto
type TaskError struct {
Code string `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
StatusCode int `json:"-"`
LocalError bool `json:"-"`
Error error `json:"-"`
}

View File

@@ -20,10 +20,10 @@ import (
_ "net/http/pprof"
)
//go:embed web/dist
// /go:embed web/dist
var buildFS embed.FS
//go:embed web/dist/index.html
// /go:embed web/dist/index.html
var indexPage []byte
func main() {

View File

@@ -125,6 +125,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
modelRequest.Model = midjourneyModel
}
c.Set("relay_mode", relayMode)
} else if strings.Contains(c.Request.URL.Path, "/suno/") {
relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
if relayMode == relayconstant.RelayModeSunoFetch ||
relayMode == relayconstant.RelayModeSunoFetchByID {
shouldSelectChannel = false
} else {
modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
modelRequest.Model = modelName
}
c.Set("platform", constant.TaskPlatformSuno)
c.Set("relay_mode", relayMode)
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}

304
model/task.go Normal file
View File

@@ -0,0 +1,304 @@
package model
import (
"database/sql/driver"
"encoding/json"
"one-api/constant"
commonRelay "one-api/relay/common"
"time"
)
type TaskStatus string
const (
TaskStatusNotStart TaskStatus = "NOT_START"
TaskStatusSubmitted = "SUBMITTED"
TaskStatusQueued = "QUEUED"
TaskStatusInProgress = "IN_PROGRESS"
TaskStatusFailure = "FAILURE"
TaskStatusSuccess = "SUCCESS"
TaskStatusUnknown = "UNKNOWN"
)
type Task struct {
ID int64 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
CreatedAt int64 `json:"created_at" gorm:"index"`
UpdatedAt int64 `json:"updated_at"`
TaskID string `json:"task_id" gorm:"type:varchar(50);index"` // 第三方id不一定有/ song id\ Task id
Platform constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
UserId int `json:"user_id" gorm:"index"`
ChannelId int `json:"channel_id" gorm:"index"`
Quota int `json:"quota"`
Action string `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
Status TaskStatus `json:"status" gorm:"type:varchar(20);index"` // 任务状态
FailReason string `json:"fail_reason"`
SubmitTime int64 `json:"submit_time" gorm:"index"`
StartTime int64 `json:"start_time" gorm:"index"`
FinishTime int64 `json:"finish_time" gorm:"index"`
Progress string `json:"progress" gorm:"type:varchar(20);index"`
Properties Properties `json:"properties" gorm:"type:json"`
Data json.RawMessage `json:"data" gorm:"type:json"`
}
func (t *Task) SetData(data any) {
b, _ := json.Marshal(data)
t.Data = json.RawMessage(b)
}
func (t *Task) GetData(v any) error {
err := json.Unmarshal(t.Data, &v)
return err
}
type Properties struct {
Input string `json:"input"`
}
func (m *Properties) Scan(val interface{}) error {
bytesValue, _ := val.([]byte)
return json.Unmarshal(bytesValue, m)
}
func (m Properties) Value() (driver.Value, error) {
return json.Marshal(m)
}
// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
type SyncTaskQueryParams struct {
Platform constant.TaskPlatform
ChannelID string
TaskID string
UserID string
Action string
Status string
StartTimestamp int64
EndTimestamp int64
UserIDs []int
}
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
t := &Task{
UserId: relayInfo.UserId,
SubmitTime: time.Now().Unix(),
Status: TaskStatusNotStart,
Progress: "0%",
ChannelId: relayInfo.ChannelId,
Platform: platform,
}
return t
}
func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
var tasks []*Task
var err error
// 初始化查询构建器
query := DB.Where("user_id = ?", userId)
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.StartTimestamp != 0 {
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
// 获取数据
err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
var tasks []*Task
var err error
// 初始化查询构建器
query := DB
// 添加过滤条件
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.Platform != "" {
query = query.Where("platform = ?", queryParams.Platform)
}
if queryParams.UserID != "" {
query = query.Where("user_id = ?", queryParams.UserID)
}
if len(queryParams.UserIDs) != 0 {
query = query.Where("user_id in (?)", queryParams.UserIDs)
}
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
// 获取数据
err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetAllUnFinishSyncTasks(limit int) []*Task {
var tasks []*Task
var err error
// get all tasks progress is not 100%
err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
if err != nil {
return nil
}
return tasks
}
func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
if taskId == "" {
return nil, false, nil
}
var task *Task
var err error
err = DB.Where("task_id = ?", taskId).First(&task).Error
exist, err := RecordExist(err)
if err != nil {
return nil, false, err
}
return task, exist, err
}
func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
if taskId == "" {
return nil, false, nil
}
var task *Task
var err error
err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
First(&task).Error
exist, err := RecordExist(err)
if err != nil {
return nil, false, err
}
return task, exist, err
}
func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
if len(taskIds) == 0 {
return nil, nil
}
var task []*Task
var err error
err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
Find(&task).Error
if err != nil {
return nil, err
}
return task, nil
}
func TaskUpdateProgress(id int64, progress string) error {
return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
}
func (Task *Task) Insert() error {
var err error
err = DB.Create(Task).Error
return err
}
func (Task *Task) Update() error {
var err error
err = DB.Save(Task).Error
return err
}
func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
if len(TaskIds) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("task_id in (?)", TaskIds).
Updates(params).Error
}
func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
if len(taskIDs) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("id in (?)", taskIDs).
Updates(params).Error
}
func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
if len(ids) == 0 {
return nil
}
return DB.Model(&Task{}).
Where("id in (?)", ids).
Updates(params).Error
}
type TaskQuotaUsage struct {
Mode string `json:"mode"`
Count float64 `json:"count"`
}
func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
query := DB.Model(Task{})
// 添加过滤条件
if queryParams.ChannelID != "" {
query = query.Where("channel_id = ?", queryParams.ChannelID)
}
if queryParams.UserID != "" {
query = query.Where("user_id = ?", queryParams.UserID)
}
if len(queryParams.UserIDs) != 0 {
query = query.Where("user_id in (?)", queryParams.UserIDs)
}
if queryParams.TaskID != "" {
query = query.Where("task_id = ?", queryParams.TaskID)
}
if queryParams.Action != "" {
query = query.Where("action = ?", queryParams.Action)
}
if queryParams.Status != "" {
query = query.Where("status = ?", queryParams.Status)
}
if queryParams.StartTimestamp != 0 {
query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
}
if queryParams.EndTimestamp != 0 {
query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
}
err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
return stat, err
}

View File

@@ -1,6 +1,8 @@
package model
import (
"errors"
"gorm.io/gorm"
"one-api/common"
"sync"
"time"
@@ -75,3 +77,13 @@ func batchUpdate() {
}
common.SysLog("batch update finished")
}
func RecordExist(err error) (bool, error) {
if err == nil {
return true, nil
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
return false, err
}

View File

@@ -19,3 +19,21 @@ type Adaptor interface {
GetModelList() []string
GetChannelName() string
}
type TaskAdaptor interface {
Init(info *relaycommon.TaskRelayInfo)
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
GetModelList() []string
GetChannelName() string
// FetchTask
}

View File

@@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
_ = c.Request.Body.Close()
return resp, nil
}
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.BuildRequestURL(info)
if err != nil {
return nil, err
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
err = a.BuildRequestHeader(c, req, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}

View File

@@ -0,0 +1,147 @@
package suno
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
)
type TaskAdaptor struct {
ChannelType int
Action string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
action := strings.ToUpper(c.Param("action"))
var sunoRequest *dto.SunoSubmitReq
err := common.UnmarshalBodyReusable(c, &sunoRequest)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
err = actionValidate(c, sunoRequest, action)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if sunoRequest.ContinueClipId != "" {
if sunoRequest.TaskID == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
return
}
info.OriginTaskID = sunoRequest.TaskID
}
a.Action = info.Action
c.Set("task_request", sunoRequest)
return nil
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
baseURL := common.ChannelBaseURLs[info.ChannelType]
if info.BaseUrl != "" {
baseURL = info.BaseUrl
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/submit/"+info.Action)
return fullRequestURL, nil
}
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
sunoRequest, ok := c.Get("task_request")
if !ok {
err := common.UnmarshalBodyReusable(c, &sunoRequest)
if err != nil {
return nil, err
}
}
data, err := json.Marshal(sunoRequest)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
var sunoResponse dto.TaskResponse[string]
err = json.Unmarshal(responseBody, &sunoResponse)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if !sunoResponse.IsSuccess() {
taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
return
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
if err != nil {
taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
return
}
return sunoResponse.Data, nil, nil
}
func (a *TaskAdaptor) GetModelList() []string {
return ModelList
}
func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
switch action {
case constant.SunoActionMusic:
if sunoRequest.Mv == "" {
sunoRequest.Mv = "chirp-v3-0"
}
case constant.SunoActionLyrics:
if sunoRequest.Prompt == "" {
err = fmt.Errorf("prompt_empty")
return
}
default:
err = fmt.Errorf("invalid_action")
}
return
}

View File

@@ -0,0 +1,7 @@
package suno
var ModelList = []string{
"suno_music", "suno_lyrics",
}
var ChannelName = "suno"

View File

@@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) {
func (info *RelayInfo) SetIsStream(isStream bool) {
info.IsStream = isStream
}
type TaskRelayInfo struct {
ChannelType int
ChannelId int
TokenId int
UserId int
Group string
StartTime time.Time
ApiType int
RelayMode int
UpstreamModelName string
RequestURLPath string
ApiKey string
BaseUrl string
Action string
OriginTaskID string
ConsumeQuota bool
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
group := c.GetString("group")
startTime := time.Now()
apiType, _ := constant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
ChannelId: channelId,
TokenId: tokenId,
UserId: userId,
Group: group,
StartTime: startTime,
ApiType: apiType,
ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
}
if info.BaseUrl == "" {
info.BaseUrl = common.ChannelBaseURLs[channelType]
}
return info
}

View File

@@ -1,6 +1,9 @@
package constant
import "strings"
import (
"net/http"
"strings"
)
const (
RelayModeUnknown = iota
@@ -26,6 +29,9 @@ const (
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeSunoFetch
RelayModeSunoFetchByID
RelayModeSunoSubmit
)
func Path2RelayMode(path string) int {
@@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int {
}
return relayMode
}
func Path2RelaySuno(method, path string) int {
relayMode := RelayModeUnknown
if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
relayMode = RelayModeSunoFetch
} else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
relayMode = RelayModeSunoFetchByID
} else if strings.Contains(path, "/submit/") {
relayMode = RelayModeSunoSubmit
}
return relayMode
}

View File

@@ -1,6 +1,7 @@
package relay
import (
commonconstant "one-api/constant"
"one-api/relay/channel"
"one-api/relay/channel/ali"
"one-api/relay/channel/aws"
@@ -12,6 +13,7 @@ import (
"one-api/relay/channel/openai"
"one-api/relay/channel/palm"
"one-api/relay/channel/perplexity"
"one-api/relay/channel/task/suno"
"one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu"
@@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor {
}
return nil
}
func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
switch platform {
//case constant.APITypeAIProxyLibrary:
// return &aiproxy.Adaptor{}
case commonconstant.TaskPlatformSuno:
return &suno.TaskAdaptor{}
}
return nil
}

242
relay/relay_task.go Normal file
View File

@@ -0,0 +1,242 @@
package relay
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
)
/*
Task 任务通过平台、Action 区分任务
*/
func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
platform := constant.TaskPlatform(c.GetString("platform"))
relayInfo := relaycommon.GenTaskRelayInfo(c)
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
}
adaptor.Init(relayInfo)
// get & validate taskRequest 获取并验证文本请求
taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
if taskErr != nil {
return
}
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
modelPrice, success := common.GetModelPrice(modelName, true)
if !success {
defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
// 预扣
groupRatio := common.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
return
}
quota := int(ratio * common.QuotaPerUnit)
if userQuota-quota < 0 {
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
return
}
if relayInfo.OriginTaskID != "" {
originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
if originTask.ChannelId != relayInfo.ChannelId {
channel, err := model.GetChannelById(originTask.ChannelId, true)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
return
}
if channel.Status != common.ChannelStatusEnabled {
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
}
c.Set("base_url", channel.GetBaseURL())
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
relayInfo.BaseUrl = channel.GetBaseURL()
relayInfo.ChannelId = originTask.ChannelId
}
}
// build body
requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
return
}
// do request
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
return
}
// handle response
if resp != nil && resp.StatusCode != http.StatusOK {
responseBody, _ := io.ReadAll(resp.Body)
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
return
}
defer func(ctx context.Context) {
// release quota
if relayInfo.ConsumeQuota && taskErr == nil {
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(relayInfo.UserId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
}
}(c.Request.Context())
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
if taskErr != nil {
return
}
relayInfo.ConsumeQuota = true
// insert task
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
task.TaskID = taskID
task.Quota = quota
task.Data = taskData
err = task.Insert()
if err != nil {
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
return
}
return nil
}
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
}
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
respBuilder, ok := fetchRespBuilders[relayMode]
if !ok {
taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
}
respBody, taskErr := respBuilder(c)
if taskErr != nil {
return taskErr
}
c.Writer.Header().Set("Content-Type", "application/json")
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
if err != nil {
taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
return
}
return
}
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
userId := c.GetInt("id")
var condition = struct {
IDs []any `json:"ids"`
Action string `json:"action"`
}{}
err := c.BindJSON(&condition)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
return
}
var tasks []any
if len(condition.IDs) > 0 {
taskModels, err := model.GetByTaskIds(userId, condition.IDs)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
return
}
for _, task := range taskModels {
tasks = append(tasks, TaskModel2Dto(task))
}
} else {
tasks = make([]any, 0)
}
respBody, err = json.Marshal(dto.TaskResponse[[]any]{
Code: "success",
Data: tasks,
})
return
}
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
taskId := c.Param("id")
userId := c.GetInt("id")
originTask, exist, err := model.GetByTaskId(userId, taskId)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
return
}
if !exist {
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
return
}
respBody, err = json.Marshal(dto.TaskResponse[any]{
Code: "success",
Data: TaskModel2Dto(originTask),
})
return
}
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
return &dto.TaskDto{
TaskID: task.TaskID,
Action: task.Action,
Status: string(task.Status),
FailReason: task.FailReason,
SubmitTime: task.SubmitTime,
StartTime: task.StartTime,
FinishTime: task.FinishTime,
Progress: task.Progress,
Data: task.Data,
}
}

View File

@@ -50,6 +50,15 @@ func SetRelayRouter(router *gin.Engine) {
relayMjModeRouter := router.Group("/:mode/mj")
registerMjRouterGroup(relayMjModeRouter)
//relayMjRouter.Use()
relaySunoRouter := router.Group("/suno")
relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
{
relaySunoRouter.POST("/submit/:action", controller.RelayTask)
relaySunoRouter.POST("/fetch", controller.RelayTask)
relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
}
}
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {

View File

@@ -105,3 +105,29 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping
openaiErr.StatusCode = intCode
}
}
func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
openaiErr := TaskErrorWrapper(err, code, statusCode)
openaiErr.LocalError = true
return openaiErr
}
func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
text := err.Error()
// 定义一个正则表达式匹配URL
if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
common.SysLog(fmt.Sprintf("error: %s", text))
text = "请求上游地址失败"
}
//避免暴露内部错误
taskError := &dto.TaskError{
Code: code,
Message: text,
StatusCode: statusCode,
Error: err,
}
return taskError
}

10
service/task.go Normal file
View File

@@ -0,0 +1,10 @@
package service
import (
"one-api/constant"
"strings"
)
func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string {
return strings.ToLower(string(platform)) + "_" + strings.ToLower(action)
}