diff --git a/controller/task_video.go b/controller/task_video.go index c39593507..8c9f9719e 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -52,6 +52,7 @@ func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, cha info.ChannelMeta = &relaycommon.ChannelMeta{ ChannelBaseUrl: cacheGetChannel.GetBaseURL(), } + info.ApiKey = cacheGetChannel.Key adaptor.Init(info) for _, taskId := range taskIds { if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { diff --git a/model/task.go b/model/task.go index 4e5b4193b..c76d26edf 100644 --- a/model/task.go +++ b/model/task.go @@ -429,3 +429,14 @@ func TaskCountAllUserTask(userId int, queryParams SyncTaskQueryParams) int64 { _ = query.Count(&total).Error return total } +func (t *Task) ToOpenAIVideo() *dto.OpenAIVideo { + openAIVideo := dto.NewOpenAIVideo() + openAIVideo.ID = t.TaskID + openAIVideo.Status = t.Status.ToVideoStatus() + openAIVideo.Model = t.Properties.OriginModelName + openAIVideo.SetProgressStr(t.Progress) + openAIVideo.CreatedAt = t.CreatedAt + openAIVideo.CompletedAt = t.UpdatedAt + openAIVideo.SetMetadata("url", t.FailReason) + return openAIVideo +} diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go new file mode 100644 index 000000000..cb6f1eebd --- /dev/null +++ b/relay/channel/task/hailuo/adaptor.go @@ -0,0 +1,297 @@ +package hailuo + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/relay/channel" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/service" +) + +// https://platform.minimaxi.com/docs/api-reference/video-generation-intro +type TaskAdaptor struct { + ChannelType int + apiKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey +} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s%s", a.baseURL, TextToVideoEndpoint), nil +} + +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+a.apiKey) + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { + v, exists := c.Get("task_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req, ok := v.(relaycommon.TaskSubmitReq) + if !ok { + return nil, fmt.Errorf("invalid request type in context") + } + + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, errors.Wrap(err, "convert request payload failed") + } + + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + + return bytes.NewReader(data), nil +} + +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + _ = resp.Body.Close() + + var hResp VideoResponse + if err := json.Unmarshal(responseBody, &hResp); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if hResp.BaseResp.StatusCode != StatusSuccess { + taskErr = service.TaskErrorWrapper( + fmt.Errorf("hailuo api error: %s", hResp.BaseResp.StatusMsg), + strconv.Itoa(hResp.BaseResp.StatusCode), + http.StatusBadRequest, + ) + return + } + + ov := dto.NewOpenAIVideo() + ov.ID = hResp.TaskID + ov.TaskID = hResp.TaskID + ov.CreatedAt = time.Now().Unix() + ov.Model = info.OriginModelName + + c.JSON(http.StatusOK, ov) + return hResp.TaskID, responseBody, nil +} + +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + uri := fmt.Sprintf("%s%s?task_id=%s", baseUrl, QueryTaskEndpoint, taskID) + + req, err := http.NewRequest(http.MethodGet, uri, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+key) + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return ModelList +} + +func (a *TaskAdaptor) GetChannelName() string { + return ChannelName +} + +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*VideoRequest, error) { + modelConfig := GetModelConfig(req.Model) + duration := DefaultDuration + if req.Duration > 0 { + duration = req.Duration + } + resolution := modelConfig.DefaultResolution + if req.Size != "" { + resolution = a.parseResolutionFromSize(req.Size, modelConfig) + } + + videoRequest := &VideoRequest{ + Model: req.Model, + Prompt: req.Prompt, + Duration: &duration, + Resolution: resolution, + } + if err := req.UnmarshalMetadata(&videoRequest); err != nil { + return nil, errors.Wrap(err, "unmarshal metadata to video request failed") + } + + return videoRequest, nil +} + +func (a *TaskAdaptor) parseResolutionFromSize(size string, modelConfig ModelConfig) string { + switch { + case strings.Contains(size, "1080"): + return Resolution1080P + case strings.Contains(size, "768"): + return Resolution768P + case strings.Contains(size, "720"): + return Resolution720P + case strings.Contains(size, "512"): + return Resolution512P + default: + return modelConfig.DefaultResolution + } +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + resTask := QueryTaskResponse{} + if err := json.Unmarshal(respBody, &resTask); err != nil { + return nil, errors.Wrap(err, "unmarshal task result failed") + } + + taskResult := relaycommon.TaskInfo{} + + if resTask.BaseResp.StatusCode == StatusSuccess { + taskResult.Code = 0 + } else { + taskResult.Code = resTask.BaseResp.StatusCode + taskResult.Reason = resTask.BaseResp.StatusMsg + taskResult.Status = model.TaskStatusFailure + taskResult.Progress = "100%" + } + + switch resTask.Status { + case TaskStatusPreparing, TaskStatusQueueing, TaskStatusProcessing: + taskResult.Status = model.TaskStatusInProgress + taskResult.Progress = "30%" + if resTask.Status == TaskStatusProcessing { + taskResult.Progress = "50%" + } + case TaskStatusSuccess: + taskResult.Status = model.TaskStatusSuccess + taskResult.Progress = "100%" + taskResult.Url = a.buildVideoURL(resTask.TaskID, resTask.FileID) + case TaskStatusFailed: + taskResult.Status = model.TaskStatusFailure + taskResult.Progress = "100%" + if taskResult.Reason == "" { + taskResult.Reason = "task failed" + } + default: + taskResult.Status = model.TaskStatusInProgress + taskResult.Progress = "30%" + } + + return &taskResult, nil +} + +func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) { + var hailuoResp QueryTaskResponse + if err := json.Unmarshal(originTask.Data, &hailuoResp); err != nil { + return nil, errors.Wrap(err, "unmarshal hailuo task data failed") + } + + openAIVideo := originTask.ToOpenAIVideo() + if hailuoResp.BaseResp.StatusCode != StatusSuccess { + openAIVideo.Error = &dto.OpenAIVideoError{ + Message: hailuoResp.BaseResp.StatusMsg, + Code: strconv.Itoa(hailuoResp.BaseResp.StatusCode), + } + } + + jsonData, err := common.Marshal(openAIVideo) + if err != nil { + return nil, errors.Wrap(err, "marshal openai video failed") + } + + return jsonData, nil +} + +func (a *TaskAdaptor) buildVideoURL(_, fileID string) string { + if a.apiKey == "" || a.baseURL == "" { + return "" + } + + url := fmt.Sprintf("%s/v1/files/retrieve?file_id=%s", a.baseURL, fileID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return "" + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+a.apiKey) + + resp, err := service.GetHttpClient().Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "" + } + + var retrieveResp RetrieveFileResponse + if err := json.Unmarshal(responseBody, &retrieveResp); err != nil { + return "" + } + + if retrieveResp.BaseResp.StatusCode != StatusSuccess { + return "" + } + + return retrieveResp.File.DownloadURL +} + +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +func containsInt(slice []int, item int) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} diff --git a/relay/channel/task/hailuo/constants.go b/relay/channel/task/hailuo/constants.go new file mode 100644 index 000000000..5e5408637 --- /dev/null +++ b/relay/channel/task/hailuo/constants.go @@ -0,0 +1,52 @@ +package hailuo + +const ( + ChannelName = "hailuo-video" +) + +var ModelList = []string{ + "MiniMax-Hailuo-2.3", + "MiniMax-Hailuo-2.3-Fast", + "MiniMax-Hailuo-02", + "T2V-01-Director", + "T2V-01", + "I2V-01-Director", + "I2V-01-live", + "I2V-01", + "S2V-01", +} + +const ( + TextToVideoEndpoint = "/v1/video_generation" + QueryTaskEndpoint = "/v1/query/video_generation" +) + +const ( + StatusSuccess = 0 + StatusRateLimit = 1002 + StatusAuthFailed = 1004 + StatusNoBalance = 1008 + StatusSensitive = 1026 + StatusParamError = 2013 + StatusInvalidKey = 2049 +) + +const ( + TaskStatusPreparing = "Preparing" + TaskStatusQueueing = "Queueing" + TaskStatusProcessing = "Processing" + TaskStatusSuccess = "Success" + TaskStatusFailed = "Fail" +) + +const ( + Resolution512P = "512P" + Resolution720P = "720P" + Resolution768P = "768P" + Resolution1080P = "1080P" +) + +const ( + DefaultDuration = 6 + DefaultResolution = Resolution720P +) diff --git a/relay/channel/task/hailuo/models.go b/relay/channel/task/hailuo/models.go new file mode 100644 index 000000000..09a97766f --- /dev/null +++ b/relay/channel/task/hailuo/models.go @@ -0,0 +1,170 @@ +package hailuo + +type SubjectReference struct { + Type string `json:"type"` // Subject type, currently only supports "character" + Image []string `json:"image"` // Array of subject reference images (currently only supports single image) +} + +type VideoRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + PromptOptimizer *bool `json:"prompt_optimizer,omitempty"` + FastPretreatment *bool `json:"fast_pretreatment,omitempty"` + Duration *int `json:"duration,omitempty"` + Resolution string `json:"resolution,omitempty"` + CallbackURL string `json:"callback_url,omitempty"` + AigcWatermark *bool `json:"aigc_watermark,omitempty"` + FirstFrameImage string `json:"first_frame_image,omitempty"` // For image-to-video and start-end-to-video + LastFrameImage string `json:"last_frame_image,omitempty"` // For start-end-to-video + SubjectReference []SubjectReference `json:"subject_reference,omitempty"` // For subject-reference-to-video +} + +type VideoResponse struct { + TaskID string `json:"task_id"` + BaseResp BaseResp `json:"base_resp"` +} + +type BaseResp struct { + StatusCode int `json:"status_code"` + StatusMsg string `json:"status_msg"` +} + +type QueryTaskRequest struct { + TaskID string `json:"task_id"` +} + +type QueryTaskResponse struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + FileID string `json:"file_id,omitempty"` + VideoWidth int `json:"video_width,omitempty"` + VideoHeight int `json:"video_height,omitempty"` + BaseResp BaseResp `json:"base_resp"` +} + +type ErrorInfo struct { + StatusCode int `json:"status_code"` + StatusMsg string `json:"status_msg"` +} + +type TaskStatusInfo struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + FileID string `json:"file_id,omitempty"` + VideoURL string `json:"video_url,omitempty"` + ErrorCode int `json:"error_code,omitempty"` + ErrorMsg string `json:"error_msg,omitempty"` +} + +type ModelConfig struct { + Name string + DefaultResolution string + SupportedDurations []int + SupportedResolutions []string + HasPromptOptimizer bool + HasFastPretreatment bool +} + +type RetrieveFileResponse struct { + File FileObject `json:"file"` + BaseResp BaseResp `json:"base_resp"` +} + +type FileObject struct { + FileID int64 `json:"file_id"` + Bytes int64 `json:"bytes"` + CreatedAt int64 `json:"created_at"` + Filename string `json:"filename"` + Purpose string `json:"purpose"` + DownloadURL string `json:"download_url"` +} + +func GetModelConfig(model string) ModelConfig { + configs := map[string]ModelConfig{ + "MiniMax-Hailuo-2.3": { + Name: "MiniMax-Hailuo-2.3", + DefaultResolution: Resolution768P, + SupportedDurations: []int{6, 10}, + SupportedResolutions: []string{Resolution768P, Resolution1080P}, + HasPromptOptimizer: true, + HasFastPretreatment: true, + }, + "MiniMax-Hailuo-2.3-Fast": { + Name: "MiniMax-Hailuo-2.3-Fast", + DefaultResolution: Resolution768P, + SupportedDurations: []int{6, 10}, + SupportedResolutions: []string{Resolution768P, Resolution1080P}, + HasPromptOptimizer: true, + HasFastPretreatment: true, + }, + "MiniMax-Hailuo-02": { + Name: "MiniMax-Hailuo-02", + DefaultResolution: Resolution768P, + SupportedDurations: []int{6, 10}, + SupportedResolutions: []string{Resolution512P, Resolution768P, Resolution1080P}, + HasPromptOptimizer: true, + HasFastPretreatment: true, + }, + "T2V-01-Director": { + Name: "T2V-01-Director", + DefaultResolution: Resolution768P, + SupportedDurations: []int{6}, + SupportedResolutions: []string{Resolution768P, Resolution1080P}, + HasPromptOptimizer: true, + HasFastPretreatment: false, + }, + "T2V-01": { + Name: "T2V-01", + DefaultResolution: Resolution720P, + SupportedDurations: []int{6}, + SupportedResolutions: []string{Resolution720P}, + HasPromptOptimizer: true, + HasFastPretreatment: false, + }, + "I2V-01-Director": { + Name: "I2V-01-Director", + DefaultResolution: Resolution720P, + SupportedDurations: []int{6}, + SupportedResolutions: []string{Resolution720P, Resolution1080P}, + HasPromptOptimizer: true, + HasFastPretreatment: false, + }, + "I2V-01-live": { + Name: "I2V-01-live", + DefaultResolution: Resolution720P, + SupportedDurations: []int{6}, + SupportedResolutions: []string{Resolution720P, Resolution1080P}, + HasPromptOptimizer: true, + HasFastPretreatment: false, + }, + "I2V-01": { + Name: "I2V-01", + DefaultResolution: Resolution720P, + SupportedDurations: []int{6}, + SupportedResolutions: []string{Resolution720P, Resolution1080P}, + HasPromptOptimizer: true, + HasFastPretreatment: false, + }, + "S2V-01": { + Name: "S2V-01", + DefaultResolution: Resolution720P, + SupportedDurations: []int{6}, + SupportedResolutions: []string{Resolution720P}, + HasPromptOptimizer: true, + HasFastPretreatment: false, + }, + } + + if config, exists := configs[model]; exists { + return config + } + + return ModelConfig{ + Name: model, + DefaultResolution: DefaultResolution, + SupportedDurations: []int{6}, + SupportedResolutions: []string{DefaultResolution}, + HasPromptOptimizer: true, + HasFastPretreatment: false, + } +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 10601298c..33ef4d14c 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -498,11 +498,11 @@ type TaskSubmitReq struct { Metadata map[string]interface{} `json:"metadata,omitempty"` } -func (t TaskSubmitReq) GetPrompt() string { +func (t *TaskSubmitReq) GetPrompt() string { return t.Prompt } -func (t TaskSubmitReq) HasImage() bool { +func (t *TaskSubmitReq) HasImage() bool { return len(t.Images) > 0 } @@ -537,6 +537,20 @@ func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error { return nil } +func (t *TaskSubmitReq) UnmarshalMetadata(v any) error { + metadata := t.Metadata + if metadata != nil { + metadataBytes, err := json.Marshal(metadata) + if err != nil { + return fmt.Errorf("marshal metadata failed: %w", err) + } + err = json.Unmarshal(metadataBytes, v) + if err != nil { + return fmt.Errorf("unmarshal metadata to target failed: %w", err) + } + } + return nil +} type TaskInfo struct { Code int `json:"code"` diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 55afea179..b838b313d 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -32,6 +32,7 @@ import ( taskali "github.com/QuantumNous/new-api/relay/channel/task/ali" taskdoubao "github.com/QuantumNous/new-api/relay/channel/task/doubao" taskGemini "github.com/QuantumNous/new-api/relay/channel/task/gemini" + "github.com/QuantumNous/new-api/relay/channel/task/hailuo" taskjimeng "github.com/QuantumNous/new-api/relay/channel/task/jimeng" "github.com/QuantumNous/new-api/relay/channel/task/kling" tasksora "github.com/QuantumNous/new-api/relay/channel/task/sora" @@ -153,6 +154,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { return &tasksora.TaskAdaptor{} case constant.ChannelTypeGemini: return &taskGemini.TaskAdaptor{} + case constant.ChannelTypeMiniMax: + return &hailuo.TaskAdaptor{} } } return nil