diff --git a/constant/channel.go b/constant/channel.go index 34fb20f46..7d8893c1d 100644 --- a/constant/channel.go +++ b/constant/channel.go @@ -51,9 +51,9 @@ const ( ChannelTypeJimeng = 51 ChannelTypeVidu = 52 ChannelTypeSubmodel = 53 + ChannelTypeDoubaoVideo = 54 ChannelTypeDummy // this one is only for count, do not add any channel after this - ) var ChannelBaseURLs = []string{ @@ -111,4 +111,5 @@ var ChannelBaseURLs = []string{ "https://visual.volcengineapi.com", //51 "https://api.vidu.cn", //52 "https://llm.submodel.ai", //53 + "https://ark.cn-beijing.volces.com", //54 } diff --git a/controller/channel-test.go b/controller/channel-test.go index b3a3be4eb..ff1e8cef4 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -70,6 +70,12 @@ func testChannel(channel *model.Channel, testModel string, endpointType string) newAPIError: nil, } } + if channel.Type == constant.ChannelTypeDoubaoVideo { + return testResult{ + localErr: errors.New("doubao video channel test is not supported"), + newAPIError: nil, + } + } if channel.Type == constant.ChannelTypeVidu { return testResult{ localErr: errors.New("vidu channel test is not supported"), diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go new file mode 100644 index 000000000..9b40a249a --- /dev/null +++ b/relay/channel/task/doubao/adaptor.go @@ -0,0 +1,245 @@ +package doubao + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/constant" + "one-api/dto" + "one-api/model" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/service" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +// ============================ +// Request / Response structures +// ============================ + +type ContentItem struct { + Type string `json:"type"` // "text" or "image_url" + Text string `json:"text,omitempty"` // for text type + ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type +} + +type ImageURL struct { + URL string `json:"url"` +} + +type requestPayload struct { + Model string `json:"model"` + Content []ContentItem `json:"content"` +} + +type responsePayload struct { + ID string `json:"id"` // task_id +} + +type responseTask struct { + ID string `json:"id"` + Model string `json:"model"` + Status string `json:"status"` + Content struct { + VideoURL string `json:"video_url"` + } `json:"content"` + Seed int `json:"seed"` + Resolution string `json:"resolution"` + Duration int `json:"duration"` + Ratio string `json:"ratio"` + FramesPerSecond int `json:"framespersecond"` + Usage struct { + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +// ============================ +// Adaptor implementation +// ============================ + +type TaskAdaptor struct { + ChannelType int + apiKey string + baseURL string +} + +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + // Accept only POST /v1/video/generations as "generate" action. + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/api/v3/contents/generations/tasks", a.baseURL), nil +} + +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+a.apiKey) + return nil +} + +// BuildRequestBody converts request into Doubao specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { + v, exists := c.Get("task_request") + if !exists { + return nil, fmt.Errorf("request not found in context") + } + req := v.(relaycommon.TaskSubmitReq) + + body, err := a.convertToRequestPayload(&req) + if err != nil { + return nil, errors.Wrap(err, "convert request payload failed") + } + data, err := json.Marshal(body) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoTaskApiRequest(a, c, info, requestBody) +} + +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + return + } + _ = resp.Body.Close() + + // Parse Doubao response + var dResp responsePayload + if err := json.Unmarshal(responseBody, &dResp); err != nil { + taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError) + return + } + + if dResp.ID == "" { + taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError) + return + } + + c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID}) + return dResp.ID, responseBody, nil +} + +// FetchTask fetch task status +func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { + taskID, ok := body["task_id"].(string) + if !ok { + return nil, fmt.Errorf("invalid task_id") + } + + uri := fmt.Sprintf("%s/api/v3/contents/generations/tasks/%s", baseUrl, taskID) + + req, err := http.NewRequest(http.MethodGet, uri, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+key) + + return service.GetHttpClient().Do(req) +} + +func (a *TaskAdaptor) GetModelList() []string { + return ModelList +} + +func (a *TaskAdaptor) GetChannelName() string { + return ChannelName +} + +func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) { + r := requestPayload{ + Model: req.Model, + Content: []ContentItem{}, + } + + // Add text prompt + if req.Prompt != "" { + r.Content = append(r.Content, ContentItem{ + Type: "text", + Text: req.Prompt, + }) + } + + // Add images if present + if req.HasImage() { + for _, imgURL := range req.Images { + r.Content = append(r.Content, ContentItem{ + Type: "image_url", + ImageURL: &ImageURL{ + URL: imgURL, + }, + }) + } + } + + // TODO: Add support for additional parameters from metadata + // such as ratio, duration, seed, etc. + // metadata := req.Metadata + // if metadata != nil { + // // Parse and apply metadata parameters + // } + + return &r, nil +} + +func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { + resTask := responseTask{} + if err := json.Unmarshal(respBody, &resTask); err != nil { + return nil, errors.Wrap(err, "unmarshal task result failed") + } + + taskResult := relaycommon.TaskInfo{ + Code: 0, + } + + // Map Doubao status to internal status + switch resTask.Status { + case "pending", "queued": + taskResult.Status = model.TaskStatusQueued + taskResult.Progress = "10%" + case "processing": + taskResult.Status = model.TaskStatusInProgress + taskResult.Progress = "50%" + case "succeeded": + taskResult.Status = model.TaskStatusSuccess + taskResult.Progress = "100%" + taskResult.Url = resTask.Content.VideoURL + case "failed": + taskResult.Status = model.TaskStatusFailure + taskResult.Progress = "100%" + taskResult.Reason = "task failed" + default: + // Unknown status, treat as processing + taskResult.Status = model.TaskStatusInProgress + taskResult.Progress = "30%" + } + + return &taskResult, nil +} diff --git a/relay/channel/task/doubao/constants.go b/relay/channel/task/doubao/constants.go new file mode 100644 index 000000000..74b416c6d --- /dev/null +++ b/relay/channel/task/doubao/constants.go @@ -0,0 +1,9 @@ +package doubao + +var ModelList = []string{ + "doubao-seedance-1-0-pro-250528", + "doubao-seedance-1-0-lite-t2v", + "doubao-seedance-1-0-lite-i2v", +} + +var ChannelName = "doubao-video" diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 406074c58..c8fd51a11 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -1,6 +1,7 @@ package relay import ( + "github.com/gin-gonic/gin" "one-api/constant" "one-api/relay/channel" "one-api/relay/channel/ali" @@ -24,6 +25,8 @@ import ( "one-api/relay/channel/palm" "one-api/relay/channel/perplexity" "one-api/relay/channel/siliconflow" + "one-api/relay/channel/submodel" + taskdoubao "one-api/relay/channel/task/doubao" taskjimeng "one-api/relay/channel/task/jimeng" "one-api/relay/channel/task/kling" "one-api/relay/channel/task/suno" @@ -37,8 +40,6 @@ import ( "one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu_4v" "strconv" - "one-api/relay/channel/submodel" - "github.com/gin-gonic/gin" ) func GetAdaptor(apiType int) channel.Adaptor { @@ -134,6 +135,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor { return &taskvertex.TaskAdaptor{} case constant.ChannelTypeVidu: return &taskVidu.TaskAdaptor{} + case constant.ChannelTypeDoubaoVideo: + return &taskdoubao.TaskAdaptor{} } } return nil diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 9ed2e8b5e..3b376ed35 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -164,6 +164,11 @@ export const CHANNEL_OPTIONS = [ color: 'blue', label: 'SubModel', }, + { + value: 54, + color: 'blue', + label: '豆包视频', + }, ]; export const MODEL_TABLE_PAGE_SIZE = 10; diff --git a/web/src/helpers/render.jsx b/web/src/helpers/render.jsx index 82d164b38..25afacec0 100644 --- a/web/src/helpers/render.jsx +++ b/web/src/helpers/render.jsx @@ -337,6 +337,8 @@ export function getChannelIcon(channelType) { return ; case 51: // 即梦 Jimeng return ; + case 54: // 豆包视频 Doubao Video + return ; case 8: // 自定义渠道 case 22: // 知识库:FastGPT return ;