mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
feat: add ali wan video (#2141)
* feat: add ali wan video * refactor: use same UnmarshalBodyReusable * feat: enhance request body metadata * feat: opt wan convertToOpenAIVideo * feat: add wan support other param via json metadata * refactor: remove unused code * fix ali --------- Co-authored-by: feitianbubu <feitianbubu@qq.com>
This commit is contained in:
@@ -2,7 +2,6 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
@@ -41,11 +40,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
//}
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = Unmarshal(requestBody, &v)
|
||||
err = Unmarshal(requestBody, v)
|
||||
} else if strings.Contains(contentType, gin.MIMEPOSTForm) {
|
||||
err = parseFormData(requestBody, &v)
|
||||
err = parseFormData(requestBody, v)
|
||||
} else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) {
|
||||
err = parseMultipartFormData(c, requestBody, &v)
|
||||
err = parseMultipartFormData(c, requestBody, v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
@@ -145,6 +144,20 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
||||
return form, nil
|
||||
}
|
||||
|
||||
func processFormMap(formMap map[string]any, v any) error {
|
||||
jsonData, err := Marshal(formMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = Unmarshal(jsonData, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseFormData(data []byte, v any) error {
|
||||
values, err := url.ParseQuery(string(data))
|
||||
if err != nil {
|
||||
@@ -158,12 +171,8 @@ func parseFormData(data []byte, v any) error {
|
||||
formMap[key] = vals
|
||||
}
|
||||
}
|
||||
jsonData, err := json.Marshal(formMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return Unmarshal(jsonData, v)
|
||||
return processFormMap(formMap, v)
|
||||
}
|
||||
|
||||
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
||||
@@ -191,10 +200,6 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
||||
formMap[key] = vals
|
||||
}
|
||||
}
|
||||
jsonData, err := Marshal(formMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return Unmarshal(jsonData, v)
|
||||
return processFormMap(formMap, v)
|
||||
}
|
||||
|
||||
@@ -91,7 +91,8 @@ func VideoProxy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if channel.Type == constant.ChannelTypeGemini {
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeGemini:
|
||||
apiKey := task.PrivateData.Key
|
||||
if apiKey == "" {
|
||||
logger.LogError(c.Request.Context(), fmt.Sprintf("Missing stored API key for Gemini task %s", taskID))
|
||||
@@ -116,7 +117,10 @@ func VideoProxy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
} else {
|
||||
case constant.ChannelTypeAli:
|
||||
// Video URL is directly in task.FailReason
|
||||
videoURL = task.FailReason
|
||||
default:
|
||||
// Default (Sora, etc.): Use original logic
|
||||
videoURL = fmt.Sprintf("%s/v1/videos/%s/content", baseURL, task.TaskID)
|
||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
|
||||
@@ -27,7 +27,7 @@ type OpenAIVideo struct {
|
||||
Size string `json:"size,omitempty"`
|
||||
RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
|
||||
Error *OpenAIVideoError `json:"error,omitempty"`
|
||||
Metadata map[string]any `json:"meta_data,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (m *OpenAIVideo) SetProgressStr(progress string) {
|
||||
|
||||
@@ -73,20 +73,22 @@ func (t *Task) GetData(v any) error {
|
||||
}
|
||||
|
||||
type Properties struct {
|
||||
Input string `json:"input"`
|
||||
Input string `json:"input"`
|
||||
UpstreamModelName string `json:"upstream_model_name,omitempty"`
|
||||
OriginModelName string `json:"origin_model_name,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Properties) Scan(val interface{}) error {
|
||||
bytesValue, _ := val.([]byte)
|
||||
if len(bytesValue) == 0 {
|
||||
m.Input = ""
|
||||
*m = Properties{}
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(bytesValue, m)
|
||||
}
|
||||
|
||||
func (m Properties) Value() (driver.Value, error) {
|
||||
if m.Input == "" {
|
||||
if m == (Properties{}) {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(m)
|
||||
@@ -127,8 +129,16 @@ type SyncTaskQueryParams struct {
|
||||
func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task {
|
||||
properties := Properties{}
|
||||
privateData := TaskPrivateData{}
|
||||
if relayInfo != nil && relayInfo.ChannelMeta != nil && relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini {
|
||||
privateData.Key = relayInfo.ChannelMeta.ApiKey
|
||||
if relayInfo != nil && relayInfo.ChannelMeta != nil {
|
||||
if relayInfo.ChannelMeta.ChannelType == constant.ChannelTypeGemini {
|
||||
privateData.Key = relayInfo.ChannelMeta.ApiKey
|
||||
}
|
||||
if relayInfo.UpstreamModelName != "" {
|
||||
properties.UpstreamModelName = relayInfo.UpstreamModelName
|
||||
}
|
||||
if relayInfo.OriginModelName != "" {
|
||||
properties.OriginModelName = relayInfo.OriginModelName
|
||||
}
|
||||
}
|
||||
|
||||
t := &Task{
|
||||
|
||||
360
relay/channel/task/ali/adaptor.go
Normal file
360
relay/channel/task/ali/adaptor.go
Normal file
@@ -0,0 +1,360 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
// AliVideoRequest 阿里通义万相视频生成请求
|
||||
type AliVideoRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliVideoInput `json:"input"`
|
||||
Parameters *AliVideoParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// AliVideoInput 视频输入参数
|
||||
type AliVideoInput struct {
|
||||
Prompt string `json:"prompt,omitempty"` // 文本提示词
|
||||
ImgURL string `json:"img_url,omitempty"` // 首帧图像URL或Base64(图生视频)
|
||||
FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL(首尾帧生视频)
|
||||
LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL(首尾帧生视频)
|
||||
AudioURL string `json:"audio_url,omitempty"` // 音频URL(wan2.5支持)
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
|
||||
Template string `json:"template,omitempty"` // 视频特效模板
|
||||
}
|
||||
|
||||
// AliVideoParameters 视频参数
|
||||
type AliVideoParameters struct {
|
||||
Resolution string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P(图生视频、首尾帧生视频)
|
||||
Size string `json:"size,omitempty"` // 尺寸: 如 "832*480"(文生视频)
|
||||
Duration int `json:"duration,omitempty"` // 时长: 3-10秒
|
||||
PromptExtend bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
|
||||
Watermark bool `json:"watermark,omitempty"` // 是否添加水印
|
||||
Audio *bool `json:"audio,omitempty"` // 是否添加音频(wan2.5)
|
||||
Seed int `json:"seed,omitempty"` // 随机数种子
|
||||
}
|
||||
|
||||
// AliVideoResponse 阿里通义万相响应
|
||||
type AliVideoResponse struct {
|
||||
Output AliVideoOutput `json:"output"`
|
||||
RequestID string `json:"request_id"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Usage *AliUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// AliVideoOutput 输出信息
|
||||
type AliVideoOutput struct {
|
||||
TaskID string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
SubmitTime string `json:"submit_time,omitempty"`
|
||||
ScheduledTime string `json:"scheduled_time,omitempty"`
|
||||
EndTime string `json:"end_time,omitempty"`
|
||||
OrigPrompt string `json:"orig_prompt,omitempty"`
|
||||
ActualPrompt string `json:"actual_prompt,omitempty"`
|
||||
VideoURL string `json:"video_url,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// AliUsage 使用统计
|
||||
type AliUsage struct {
|
||||
Duration int `json:"duration,omitempty"`
|
||||
VideoCount int `json:"video_count,omitempty"`
|
||||
SR int `json:"SR,omitempty"`
|
||||
}
|
||||
|
||||
type AliMetadata struct {
|
||||
// Input 相关
|
||||
AudioURL string `json:"audio_url,omitempty"` // 音频URL
|
||||
ImgURL string `json:"img_url,omitempty"` // 图片URL(图生视频)
|
||||
FirstFrameURL string `json:"first_frame_url,omitempty"` // 首帧图片URL(首尾帧生视频)
|
||||
LastFrameURL string `json:"last_frame_url,omitempty"` // 尾帧图片URL(首尾帧生视频)
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"` // 反向提示词
|
||||
Template string `json:"template,omitempty"` // 视频特效模板
|
||||
|
||||
// Parameters 相关
|
||||
Resolution *string `json:"resolution,omitempty"` // 分辨率: 480P/720P/1080P
|
||||
Size *string `json:"size,omitempty"` // 尺寸: 如 "832*480"
|
||||
Duration *int `json:"duration,omitempty"` // 时长
|
||||
PromptExtend *bool `json:"prompt_extend,omitempty"` // 是否开启prompt智能改写
|
||||
Watermark *bool `json:"watermark,omitempty"` // 是否添加水印
|
||||
Audio *bool `json:"audio,omitempty"` // 是否添加音频
|
||||
Seed *int `json:"seed,omitempty"` // 随机数种子
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// 阿里通义万相支持 JSON 格式,不使用 multipart
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v1/services/aigc/video-generation/video-synthesis", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers for Ali API
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Authorization", "Bearer "+a.apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-DashScope-Async", "enable") // 阿里异步任务必须设置
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
var taskReq relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal_task_request_failed")
|
||||
}
|
||||
aliReq := a.convertToAliRequest(taskReq)
|
||||
|
||||
bodyBytes, err := common.Marshal(aliReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal_ali_request_failed")
|
||||
}
|
||||
|
||||
return bytes.NewReader(bodyBytes), nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) convertToAliRequest(req relaycommon.TaskSubmitReq) *AliVideoRequest {
|
||||
aliReq := &AliVideoRequest{
|
||||
Model: req.Model,
|
||||
Input: AliVideoInput{
|
||||
Prompt: req.Prompt,
|
||||
ImgURL: req.InputReference,
|
||||
},
|
||||
Parameters: &AliVideoParameters{
|
||||
PromptExtend: true, // 默认开启智能改写
|
||||
Watermark: false,
|
||||
},
|
||||
}
|
||||
|
||||
// 处理分辨率映射
|
||||
if req.Size != "" {
|
||||
resolution := strings.ToUpper(req.Size)
|
||||
// 支持 480p, 720p, 1080p 或 480P, 720P, 1080P
|
||||
if !strings.HasSuffix(resolution, "P") {
|
||||
resolution = resolution + "P"
|
||||
}
|
||||
aliReq.Parameters.Resolution = resolution
|
||||
} else {
|
||||
// 根据模型设置默认分辨率
|
||||
if strings.HasPrefix(req.Model, "wan2.5") {
|
||||
aliReq.Parameters.Resolution = "1080P"
|
||||
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-flash") {
|
||||
aliReq.Parameters.Resolution = "720P"
|
||||
} else if strings.HasPrefix(req.Model, "wan2.2-i2v-plus") {
|
||||
aliReq.Parameters.Resolution = "1080P"
|
||||
} else {
|
||||
aliReq.Parameters.Resolution = "720P"
|
||||
}
|
||||
}
|
||||
|
||||
// 处理时长
|
||||
if req.Duration > 0 {
|
||||
aliReq.Parameters.Duration = req.Duration
|
||||
} else {
|
||||
aliReq.Parameters.Duration = 5 // 默认5秒
|
||||
}
|
||||
|
||||
// 从 metadata 中提取额外参数
|
||||
if req.Metadata != nil {
|
||||
if metadataBytes, err := common.Marshal(req.Metadata); err == nil {
|
||||
_ = common.Unmarshal(metadataBytes, aliReq)
|
||||
}
|
||||
}
|
||||
|
||||
return aliReq
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// 解析阿里响应
|
||||
var aliResp AliVideoResponse
|
||||
if err := common.Unmarshal(responseBody, &aliResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if aliResp.Code != "" {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("%s: %s", aliResp.Code, aliResp.Message), "ali_api_error", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
|
||||
if aliResp.Output.TaskID == "" {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为 OpenAI 格式响应
|
||||
openAIResp := dto.NewOpenAIVideo()
|
||||
openAIResp.ID = aliResp.Output.TaskID
|
||||
openAIResp.Model = c.GetString("model")
|
||||
if openAIResp.Model == "" && info != nil {
|
||||
openAIResp.Model = info.OriginModelName
|
||||
}
|
||||
openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
|
||||
openAIResp.CreatedAt = common.GetTimestamp()
|
||||
|
||||
// 返回 OpenAI 格式
|
||||
c.JSON(http.StatusOK, openAIResp)
|
||||
|
||||
return aliResp.Output.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask 查询任务状态
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s/api/v1/tasks/%s", baseUrl, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, uri, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
|
||||
// ParseTaskResult 解析任务结果
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var aliResp AliVideoResponse
|
||||
if err := common.Unmarshal(respBody, &aliResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
|
||||
taskResult := relaycommon.TaskInfo{
|
||||
Code: 0,
|
||||
}
|
||||
|
||||
// 状态映射
|
||||
switch aliResp.Output.TaskStatus {
|
||||
case "PENDING":
|
||||
taskResult.Status = model.TaskStatusQueued
|
||||
case "RUNNING":
|
||||
taskResult.Status = model.TaskStatusInProgress
|
||||
case "SUCCEEDED":
|
||||
taskResult.Status = model.TaskStatusSuccess
|
||||
// 阿里直接返回视频URL,不需要额外的代理端点
|
||||
taskResult.Url = aliResp.Output.VideoURL
|
||||
case "FAILED", "CANCELED", "UNKNOWN":
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
if aliResp.Message != "" {
|
||||
taskResult.Reason = aliResp.Message
|
||||
} else if aliResp.Output.Message != "" {
|
||||
taskResult.Reason = fmt.Sprintf("task failed, code: %s , message: %s", aliResp.Output.Code, aliResp.Output.Message)
|
||||
} else {
|
||||
taskResult.Reason = "task failed"
|
||||
}
|
||||
default:
|
||||
taskResult.Status = model.TaskStatusQueued
|
||||
}
|
||||
|
||||
return &taskResult, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
var aliResp AliVideoResponse
|
||||
if err := common.Unmarshal(task.Data, &aliResp); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal ali response failed")
|
||||
}
|
||||
|
||||
openAIResp := dto.NewOpenAIVideo()
|
||||
openAIResp.ID = task.TaskID
|
||||
openAIResp.Status = convertAliStatus(aliResp.Output.TaskStatus)
|
||||
openAIResp.Model = task.Properties.OriginModelName
|
||||
openAIResp.SetProgressStr(task.Progress)
|
||||
openAIResp.CreatedAt = task.CreatedAt
|
||||
openAIResp.CompletedAt = task.UpdatedAt
|
||||
|
||||
// 设置视频URL(核心字段)
|
||||
openAIResp.SetMetadata("url", aliResp.Output.VideoURL)
|
||||
|
||||
// 错误处理
|
||||
if aliResp.Code != "" {
|
||||
openAIResp.Error = &dto.OpenAIVideoError{
|
||||
Code: aliResp.Code,
|
||||
Message: aliResp.Message,
|
||||
}
|
||||
} else if aliResp.Output.Code != "" {
|
||||
openAIResp.Error = &dto.OpenAIVideoError{
|
||||
Code: aliResp.Output.Code,
|
||||
Message: aliResp.Output.Message,
|
||||
}
|
||||
}
|
||||
|
||||
return common.Marshal(openAIResp)
|
||||
}
|
||||
|
||||
func convertAliStatus(aliStatus string) string {
|
||||
switch aliStatus {
|
||||
case "PENDING":
|
||||
return dto.VideoStatusQueued
|
||||
case "RUNNING":
|
||||
return dto.VideoStatusInProgress
|
||||
case "SUCCEEDED":
|
||||
return dto.VideoStatusCompleted
|
||||
case "FAILED", "CANCELED", "UNKNOWN":
|
||||
return dto.VideoStatusFailed
|
||||
default:
|
||||
return dto.VideoStatusUnknown
|
||||
}
|
||||
}
|
||||
11
relay/channel/task/ali/constants.go
Normal file
11
relay/channel/task/ali/constants.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package ali
|
||||
|
||||
var ModelList = []string{
|
||||
"wan2.5-i2v-preview", // 万相2.5 preview(有声视频)推荐
|
||||
"wan2.2-i2v-flash", // 万相2.2极速版(无声视频)
|
||||
"wan2.2-i2v-plus", // 万相2.2专业版(无声视频)
|
||||
"wanx2.1-i2v-plus", // 万相2.1专业版(无声视频)
|
||||
"wanx2.1-i2v-turbo", // 万相2.1极速版(无声视频)
|
||||
}
|
||||
|
||||
var ChannelName = "ali"
|
||||
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -485,14 +486,16 @@ type TaskRelayInfo struct {
|
||||
}
|
||||
|
||||
type TaskSubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Seconds string `json:"seconds,omitempty"`
|
||||
InputReference string `json:"input_reference,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (t TaskSubmitReq) GetPrompt() string {
|
||||
@@ -503,6 +506,38 @@ func (t TaskSubmitReq) HasImage() bool {
|
||||
return len(t.Images) > 0
|
||||
}
|
||||
|
||||
func (t *TaskSubmitReq) UnmarshalJSON(data []byte) error {
|
||||
type Alias TaskSubmitReq
|
||||
aux := &struct {
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(t),
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(aux.Metadata) > 0 {
|
||||
var metadataStr string
|
||||
if err := common.Unmarshal(aux.Metadata, &metadataStr); err == nil && metadataStr != "" {
|
||||
var metadataObj map[string]interface{}
|
||||
if err := common.Unmarshal([]byte(metadataStr), &metadataObj); err == nil {
|
||||
t.Metadata = metadataObj
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var metadataObj map[string]interface{}
|
||||
if err := common.Unmarshal(aux.Metadata, &metadataObj); err == nil {
|
||||
t.Metadata = metadataObj
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type TaskInfo struct {
|
||||
Code int `json:"code"`
|
||||
TaskID string `json:"task_id"`
|
||||
|
||||
@@ -108,62 +108,33 @@ func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string
|
||||
}
|
||||
|
||||
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
var prompt string
|
||||
var model string
|
||||
var seconds int
|
||||
var size string
|
||||
var hasInputReference bool
|
||||
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
form, err := common.ParseMultipartFormReusable(c)
|
||||
if err != nil {
|
||||
return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
|
||||
}
|
||||
defer form.RemoveAll()
|
||||
var req TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
prompts, ok := form.Value["prompt"]
|
||||
if !ok || len(prompts) == 0 {
|
||||
return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
|
||||
}
|
||||
prompt = prompts[0]
|
||||
|
||||
if _, ok := form.Value["model"]; !ok {
|
||||
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
|
||||
}
|
||||
model = form.Value["model"][0]
|
||||
|
||||
if _, ok := form.File["input_reference"]; ok {
|
||||
hasInputReference = true
|
||||
}
|
||||
|
||||
if ss, ok := form.Value["seconds"]; ok {
|
||||
sInt := common.String2Int(ss[0])
|
||||
if sInt > seconds {
|
||||
seconds = common.String2Int(ss[0])
|
||||
}
|
||||
}
|
||||
|
||||
if sz, ok := form.Value["size"]; ok {
|
||||
size = sz[0]
|
||||
}
|
||||
} else {
|
||||
var req TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
prompt = req.Prompt
|
||||
model = req.Model
|
||||
prompt = req.Prompt
|
||||
model = req.Model
|
||||
seconds, _ = strconv.Atoi(req.Seconds)
|
||||
if seconds == 0 {
|
||||
seconds = req.Duration
|
||||
}
|
||||
if req.InputReference != "" {
|
||||
req.Images = []string{req.InputReference}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
|
||||
}
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if req.HasImage() {
|
||||
hasInputReference = true
|
||||
}
|
||||
if req.HasImage() {
|
||||
hasInputReference = true
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(prompt); taskErr != nil {
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel/perplexity"
|
||||
"github.com/QuantumNous/new-api/relay/channel/siliconflow"
|
||||
"github.com/QuantumNous/new-api/relay/channel/submodel"
|
||||
taskali "github.com/QuantumNous/new-api/relay/channel/task/ali"
|
||||
taskdoubao "github.com/QuantumNous/new-api/relay/channel/task/doubao"
|
||||
taskGemini "github.com/QuantumNous/new-api/relay/channel/task/gemini"
|
||||
taskjimeng "github.com/QuantumNous/new-api/relay/channel/task/jimeng"
|
||||
@@ -133,6 +134,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
}
|
||||
if channelType, err := strconv.ParseInt(string(platform), 10, 64); err == nil {
|
||||
switch channelType {
|
||||
case constant.ChannelTypeAli:
|
||||
return &taskali.TaskAdaptor{}
|
||||
case constant.ChannelTypeKling:
|
||||
return &kling.TaskAdaptor{}
|
||||
case constant.ChannelTypeJimeng:
|
||||
|
||||
Reference in New Issue
Block a user