From ab30f584cc94537b5a724a049a53ad036f8757c7 Mon Sep 17 00:00:00 2001 From: Seefs <40468931+seefs001@users.noreply.github.com> Date: Fri, 31 Oct 2025 16:51:05 +0800 Subject: [PATCH] feat: add ali wan video (#2141) * feat: add ali wan video * refactor: use same UnmarshalBodyReusable * feat: enhance request body metadata * feat: opt wan convertToOpenAIVideo * feat: add wan support other param via json metadata * refactor: remove unused code * fix ali --------- Co-authored-by: feitianbubu --- common/gin.go | 33 +-- controller/video_proxy.go | 8 +- dto/openai_video.go | 2 +- model/task.go | 20 +- relay/channel/task/ali/adaptor.go | 360 ++++++++++++++++++++++++++++ relay/channel/task/ali/constants.go | 11 + relay/common/relay_info.go | 51 +++- relay/common/relay_utils.go | 63 ++--- relay/relay_adaptor.go | 3 + 9 files changed, 475 insertions(+), 76 deletions(-) create mode 100644 relay/channel/task/ali/adaptor.go create mode 100644 relay/channel/task/ali/constants.go diff --git a/common/gin.go b/common/gin.go index e8d8bda3a..cc7164e4b 100644 --- a/common/gin.go +++ b/common/gin.go @@ -2,7 +2,6 @@ package common import ( "bytes" - "encoding/json" "io" "mime/multipart" "net/http" @@ -41,11 +40,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { //} contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = Unmarshal(requestBody, &v) + err = Unmarshal(requestBody, v) } else if strings.Contains(contentType, gin.MIMEPOSTForm) { - err = parseFormData(requestBody, &v) + err = parseFormData(requestBody, v) } else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) { - err = parseMultipartFormData(c, requestBody, &v) + err = parseMultipartFormData(c, requestBody, v) } else { // skip for now // TODO: someday non json request have variant model, we will need to implementation this @@ -145,6 +144,20 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { return form, nil } +func processFormMap(formMap map[string]any, v any) error { + jsonData, err := Marshal(formMap) + if err != nil { + return err + } + + err = Unmarshal(jsonData, v) + if err != nil { + return err + } + + return nil +} + func parseFormData(data []byte, v any) error { values, err := url.ParseQuery(string(data)) if err != nil { @@ -158,12 +171,8 @@ func parseFormData(data []byte, v any) error { formMap[key] = vals } } - jsonData, err := json.Marshal(formMap) - if err != nil { - return err - } - return Unmarshal(jsonData, v) + return processFormMap(formMap, v) } func parseMultipartFormData(c *gin.Context, data []byte, v any) error { @@ -191,10 +200,6 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error { formMap[key] = vals } } - jsonData, err := Marshal(formMap) - if err != nil { - return err - } - return Unmarshal(jsonData, v) + return processFormMap(formMap, v) } diff --git a/controller/video_proxy.go b/controller/video_proxy.go index c9801b4b5..829a94c1b 100644 --- a/controller/video_proxy.go +++ b/controller/video_proxy.go @@ -91,7 +91,8 @@ func VideoProxy(c *gin.Context) { return } - if channel.Type == constant.ChannelTypeGemini { + switch channel.Type { + case constant.ChannelTypeGemini: apiKey := task.PrivateData.Key if apiKey == "" { logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID)) @@ -116,7 +117,10 @@ func VideoProxy(c *gin.Context) { return } req.Header.Set("x-goog-api-key", apiKey) - } else { + case constant.ChannelTypeAli: + // Video URL is directly in task.FailReason + videoURL = task.FailReason + default: // Default (Sora, etc.): Use original logic videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID) req.Header.Set("Authorization", "Bearer "+channel.Key) diff --git a/dto/openai_video.go b/dto/openai_video.go index 051769b98..bee64c31f 100644 --- a/dto/openai_video.go +++ b/dto/openai_video.go @@ -27,7 +27,7 @@ type OpenAIVideo struct { Size string `json:"size,omitempty"` RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"` Error *OpenAIVideoError `json:"error,omitempty"` - Metadata map[string]any `json:"meta_data,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } func (m *OpenAIVideo) SetProgressStr(progress string) { diff --git a/model/task.go b/model/task.go index 994dd25c7..4e5b4193b 100644 --- a/model/task.go +++ b/model/task.go @@ -73,20 +73,22 @@ func (t *Task) GetData(v any) error { } type Properties struct { - Input string `json:"input"` + Input string `json:"input"` + UpstreamModelName string `json:"upstream_model_name,omitempty"` + OriginModelName string `json:"origin_model_name,omitempty"` } func (m *Properties) Scan(val interface{}) error { bytesValue, _ := val.([]byte) if len(bytesValue) == 0 { - m.Input = "" + *m = Properties{} return nil } return json.Unmarshal(bytesValue, m) } func (m Properties) Value() (driver.Value, error) { - if m.Input == "" { + if m == (Properties{}) { return nil, nil } return json.Marshal(m) @@ -127,8 +129,16 @@ type SyncTaskQueryParams struct { func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task { properties := Properties{} privateData := TaskPrivateData{} - if relayInfo != nil && relayInfo.ChannelMeta != nil && relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini { - privateData.Key = relayInfo.ChannelMeta.ApiKey + if relayInfo != nil && relayInfo.ChannelMeta != nil { + if relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini { + privateData.Key = relayInfo.ChannelMeta.ApiKey + } + if relayInfo.UpstreamModelName != "" { + properties.UpstreamModelName = relayInfo.UpstreamModelName + } + if relayInfo.OriginModelName != "" { + properties.OriginModelName = relayInfo.OriginModelName + } } t := &Task{ diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go new file mode 100644 index 000000000..a40d343e5 --- /dev/null +++ b/relay/channel/task/ali/adaptor.go @@ -0,0 +1,360 @@ +package ali + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strings" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/relay/channel" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/service" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +// ============================ +// Request / Response structures +// ============================ + +// AliVideoRequest 阿里通义万相视频生成请求 +type AliVideoRequest struct { + Model string `json:"model"` + Input AliVideoInput `json:"input"` + Parameters *AliVideoParameters `json:"parameters,omitempty"` +} + +// AliVideoInput 视频输入参数 +type AliVideoInput struct { + Prompt string `json:"prompt,omitempty"` // 文本提示词 + ImgURL string `json:"img_url,omitempty"` // 首帧图像URL或Base64(图生视频) + FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL(首尾帧生视频) + LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL(首尾帧生视频) + AudioURL string `json:"audio_url,omitempty"` // 音频URL(wan2.5支持) + NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词 + Template string `json:"template,omitempty"` // 视频特效模板 +} + +// AliVideoParameters 视频参数 +type AliVideoParameters struct { + Resolution string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P(图生视频、首尾帧生视频) + Size string `json:"size,omitempty"` // 尺寸: 如 "832*480"(文生视频) + Duration int `json:"duration,omitempty"` // 时长: 3-10秒 + PromptExtend bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写 + Watermark bool `json:"watermark,omitempty"` // 是否添加水印 + Audio *bool `json:"audio,omitempty"` // 是否添加音频(wan2.5) + Seed int `json:"seed,omitempty"` // 随机数种子 +} + +// AliVideoResponse 阿里通义万相响应 +type AliVideoResponse struct { + Output AliVideoOutput `json:"output"` + RequestID string `json:"request_id"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Usage *AliUsage `json:"usage,omitempty"` +} + +// AliVideoOutput 输出信息 +type AliVideoOutput struct { + TaskID string `json:"task_id"` + TaskStatus string `json:"task_status"` + SubmitTime string `json:"submit_time,omitempty"` + ScheduledTime string `json:"scheduled_time,omitempty"` + EndTime string `json:"end_time,omitempty"` + OrigPrompt string `json:"orig_prompt,omitempty"` + ActualPrompt string `json:"actual_prompt,omitempty"` + VideoURL string `json:"video_url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +// AliUsage 使用统计 +type AliUsage struct { + Duration int `json:"duration,omitempty"` + VideoCount int `json:"video_count,omitempty"` + SR int `json:"SR,omitempty"` +} + +type AliMetadata struct { + // Input 相关 + AudioURL string `json:"audio_url,omitempty"` // 音频URL + ImgURL string `json:"img_url,omitempty"` // 图片URL(图生视频) + FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL(首尾帧生视频) + LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL(首尾帧生视频) + NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词 + Template string `json:"template,omitempty"` // 视频特效模板 + + // Parameters 相关 + Resolution *string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P + Size *string `json:"size,omitempty"` // 尺寸: 如 "832*480" + Duration *int `json:"duration,omitempty"` // 时长 + PromptExtend *bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写 + Watermark *bool `json:"watermark,omitempty"` // 是否添加水印 + Audio *bool `json:"audio,omitempty"` // 是否添加音频 + Seed *int `json:"seed,omitempty"` // 随机数种子 +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + apiKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey +} + +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + // 阿里通义万相支持 JSON 格式,不使用 multipart + return relaycommon.ValidateMultipartDirect(c, info) +} + +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/api/v1/services/aigc/video-generation/video-synthesis", a.baseURL), nil +} + +// BuildRequestHeader sets required headers for Ali API +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + req.Header.Set("Authorization", "Bearer "+a.apiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-DashScope-Async", "enable") // 阿里异步任务必须设置 + return nil +} + +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { + var taskReq relaycommon.TaskSubmitReq + if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil { + return nil, errors.Wrap(err, "unmarshal_task_request_failed") + } + aliReq := a.convertToAliRequest(taskReq) + + bodyBytes, err := common.Marshal(aliReq) + if err != nil { + return nil, errors.Wrap(err, "marshal_ali_request_failed") + } + + return bytes.NewReader(bodyBytes), nil +} + +func (a *TaskAdaptor) convertToAliRequest(req relaycommon.TaskSubmitReq) *AliVideoRequest { + aliReq := &AliVideoRequest{ + Model: req.Model, + Input: AliVideoInput{ + Prompt: req.Prompt, + ImgURL: req.InputReference, + }, + Parameters: &AliVideoParameters{ + PromptExtend: true, // 默认开启智能改写 + Watermark: false, + }, + } + + // 处理分辨率映射 + if req.Size != "" { + resolution := strings.ToUpper(req.Size) + // 支持 480p, 720p, 1080p 或 480P, 720P, 1080P + if !strings.HasSuffix(resolution, "P") { + resolution = resolution + "P" + } + aliReq.Parameters.Resolution = resolution + } else { + // 根据模型设置默认分辨率 + if strings.HasPrefix(req.Model, "wan2.5") { + aliReq.Parameters.Resolution = "1080P" + } else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") { + aliReq.Parameters.Resolution = "720P" + } else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") { + aliReq.Parameters.Resolution = "1080P" + } else { + aliReq.Parameters.Resolution = "720P" + } + } + + // 处理时长 + if req.Duration > 0 { + aliReq.Parameters.Duration = req.Duration + } else { + aliReq.Parameters.Duration = 5 // 默认5秒 + } + + // 从 metadata 中提取额外参数 + if req.Metadata != nil { + if metadataBytes, err := common.Marshal(req.Metadata); err == nil { + _ = common.Unmarshal(metadataBytes, aliReq) + } + } + + return aliReq +} + +// DoRequest delegates to common helper +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + _ = resp.Body.Close() + + // 解析阿里响应 + var aliResp AliVideoResponse + if err := common.Unmarshal(responseBody, &aliResp); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + // 检查错误 + if aliResp.Code != "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("%s: %s", aliResp.Code, aliResp.Message), "ali_api_error", resp.StatusCode) + return + } + + if aliResp.Output.TaskID == "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) + return + } + + // 转换为 OpenAI 格式响应 + openAIResp := dto.NewOpenAIVideo() + openAIResp.ID = aliResp.Output.TaskID + openAIResp.Model = c.GetString("model") + if openAIResp.Model == "" && info != nil { + openAIResp.Model = info.OriginModelName + } + openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus) + openAIResp.CreatedAt = common.GetTimestamp() + + // 返回 OpenAI 格式 + c.JSON(http.StatusOK, openAIResp) + + return aliResp.Output.TaskID, responseBody, nil +} + +// FetchTask 查询任务状态 +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + uri := fmt.Sprintf("%s/api/v1/tasks/%s", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, uri, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+key) + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return ModelList +} + +func (a *TaskAdaptor) GetChannelName() string { + return ChannelName +} + +// ParseTaskResult 解析任务结果 +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + var aliResp AliVideoResponse + if err := common.Unmarshal(respBody, &aliResp); err != nil { + return nil, errors.Wrap(err, "unmarshal task result failed") + } + + taskResult := relaycommon.TaskInfo{ + Code: 0, + } + + // 状态映射 + switch aliResp.Output.TaskStatus { + case "PENDING": + taskResult.Status = model.TaskStatusQueued + case "RUNNING": + taskResult.Status = model.TaskStatusInProgress + case "SUCCEEDED": + taskResult.Status = model.TaskStatusSuccess + // 阿里直接返回视频URL,不需要额外的代理端点 + taskResult.Url = aliResp.Output.VideoURL + case "FAILED", "CANCELED", "UNKNOWN": + taskResult.Status = model.TaskStatusFailure + if aliResp.Message != "" { + taskResult.Reason = aliResp.Message + } else if aliResp.Output.Message != "" { + taskResult.Reason = fmt.Sprintf("task failed, code: %s , message: %s", aliResp.Output.Code, aliResp.Output.Message) + } else { + taskResult.Reason = "task failed" + } + default: + taskResult.Status = model.TaskStatusQueued + } + + return &taskResult, nil +} + +func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { + var aliResp AliVideoResponse + if err := common.Unmarshal(task.Data, &aliResp); err != nil { + return nil, errors.Wrap(err, "unmarshal ali response failed") + } + + openAIResp := dto.NewOpenAIVideo() + openAIResp.ID = task.TaskID + openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus) + openAIResp.Model = task.Properties.OriginModelName + openAIResp.SetProgressStr(task.Progress) + openAIResp.CreatedAt = task.CreatedAt + openAIResp.CompletedAt = task.UpdatedAt + + // 设置视频URL(核心字段) + openAIResp.SetMetadata("url", aliResp.Output.VideoURL) + + // 错误处理 + if aliResp.Code != "" { + openAIResp.Error = &dto.OpenAIVideoError{ + Code: aliResp.Code, + Message: aliResp.Message, + } + } else if aliResp.Output.Code != "" { + openAIResp.Error = &dto.OpenAIVideoError{ + Code: aliResp.Output.Code, + Message: aliResp.Output.Message, + } + } + + return common.Marshal(openAIResp) +} + +func convertAliStatus(aliStatus string) string { + switch aliStatus { + case "PENDING": + return dto.VideoStatusQueued + case "RUNNING": + return dto.VideoStatusInProgress + case "SUCCEEDED": + return dto.VideoStatusCompleted + case "FAILED", "CANCELED", "UNKNOWN": + return dto.VideoStatusFailed + default: + return dto.VideoStatusUnknown + } +} diff --git a/relay/channel/task/ali/constants.go b/relay/channel/task/ali/constants.go new file mode 100644 index 000000000..8dc64ec59 --- /dev/null +++ b/relay/channel/task/ali/constants.go @@ -0,0 +1,11 @@ +package ali + +var ModelList = []string{ + "wan2.5-i2v-preview", // 万相2.5 preview(有声视频)推荐 + "wan2.2-i2v-flash", // 万相2.2极速版(无声视频) + "wan2.2-i2v-plus", // 万相2.2专业版(无声视频) + "wanx2.1-i2v-plus", // 万相2.1专业版(无声视频) + "wanx2.1-i2v-turbo", // 万相2.1极速版(无声视频) +} + +var ChannelName = "ali" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b67c4143d..10601298c 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -1,6 +1,7 @@ package common import ( + "encoding/json" "errors" "fmt" "strings" @@ -485,14 +486,16 @@ type TaskRelayInfo struct { } type TaskSubmitReq struct { - Prompt string `json:"prompt"` - Model string `json:"model,omitempty"` - Mode string `json:"mode,omitempty"` - Image string `json:"image,omitempty"` - Images []string `json:"images,omitempty"` - Size string `json:"size,omitempty"` - Duration int `json:"duration,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` + Prompt string `json:"prompt"` + Model string `json:"model,omitempty"` + Mode string `json:"mode,omitempty"` + Image string `json:"image,omitempty"` + Images []string `json:"images,omitempty"` + Size string `json:"size,omitempty"` + Duration int `json:"duration,omitempty"` + Seconds string `json:"seconds,omitempty"` + InputReference string `json:"input_reference,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` } func (t TaskSubmitReq) GetPrompt() string { @@ -503,6 +506,38 @@ func (t TaskSubmitReq) HasImage() bool { return len(t.Images) > 0 } +func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error { + type Alias TaskSubmitReq + aux := &struct { + Metadata json.RawMessage `json:"metadata,omitempty"` + *Alias + }{ + Alias: (*Alias)(t), + } + + if err := common.Unmarshal(data, &aux); err != nil { + return err + } + + if len(aux.Metadata) > 0 { + var metadataStr string + if err := common.Unmarshal(aux.Metadata, &metadataStr); err == nil && metadataStr != "" { + var metadataObj map[string]interface{} + if err := common.Unmarshal([]byte(metadataStr), &metadataObj); err == nil { + t.Metadata = metadataObj + return nil + } + } + + var metadataObj map[string]interface{} + if err := common.Unmarshal(aux.Metadata, &metadataObj); err == nil { + t.Metadata = metadataObj + } + } + + return nil +} + type TaskInfo struct { Code int `json:"code"` TaskID string `json:"task_id"` diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index b38baf13a..1cb2b9863 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -108,62 +108,33 @@ func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string } func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { - contentType := c.GetHeader("Content-Type") var prompt string var model string var seconds int var size string var hasInputReference bool - if strings.HasPrefix(contentType, "multipart/form-data") { - form, err := common.ParseMultipartFormReusable(c) - if err != nil { - return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true) - } - defer form.RemoveAll() + var req TaskSubmitReq + if err := common.UnmarshalBodyReusable(c, &req); err != nil { + return createTaskError(err, "invalid_json", http.StatusBadRequest, true) + } - prompts, ok := form.Value["prompt"] - if !ok || len(prompts) == 0 { - return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true) - } - prompt = prompts[0] - - if _, ok := form.Value["model"]; !ok { - return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) - } - model = form.Value["model"][0] - - if _, ok := form.File["input_reference"]; ok { - hasInputReference = true - } - - if ss, ok := form.Value["seconds"]; ok { - sInt := common.String2Int(ss[0]) - if sInt > seconds { - seconds = common.String2Int(ss[0]) - } - } - - if sz, ok := form.Value["size"]; ok { - size = sz[0] - } - } else { - var req TaskSubmitReq - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - return createTaskError(err, "invalid_json", http.StatusBadRequest, true) - } - - prompt = req.Prompt - model = req.Model + prompt = req.Prompt + model = req.Model + seconds, _ = strconv.Atoi(req.Seconds) + if seconds == 0 { seconds = req.Duration + } + if req.InputReference != "" { + req.Images = []string{req.InputReference} + } - if strings.TrimSpace(req.Model) == "" { - return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) - } + if strings.TrimSpace(req.Model) == "" { + return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true) + } - if req.HasImage() { - hasInputReference = true - } + if req.HasImage() { + hasInputReference = true } if taskErr := validatePrompt(prompt); taskErr != nil { diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 6edb9a8cf..85a0b6396 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -28,6 +28,7 @@ import ( "github.com/QuantumNous/new-api/relay/channel/perplexity" "github.com/QuantumNous/new-api/relay/channel/siliconflow" "github.com/QuantumNous/new-api/relay/channel/submodel" + taskali "github.com/QuantumNous/new-api/relay/channel/task/ali" taskdoubao "github.com/QuantumNous/new-api/relay/channel/task/doubao" taskGemini "github.com/QuantumNous/new-api/relay/channel/task/gemini" taskjimeng "github.com/QuantumNous/new-api/relay/channel/task/jimeng" @@ -133,6 +134,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { } if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil { switch channelType { + case constant.ChannelTypeAli: + return &taskali.TaskAdaptor{} case constant.ChannelTypeKling: return &kling.TaskAdaptor{} case constant.ChannelTypeJimeng: