diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 02de99567..964c33256 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -4,6 +4,7 @@ import ( "io" "net/http" "one-api/dto" + "one-api/model" relaycommon "one-api/relay/common" "one-api/types" @@ -49,3 +50,7 @@ type TaskAdaptor interface { ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } + +type OpenAIVideoConverter interface { + ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) +} diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index db9d9c3bf..3ceed42d8 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -184,3 +184,12 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e return &taskResult, nil } + +func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) (*relaycommon.OpenAIVideo, error) { + openAIVideo := &relaycommon.OpenAIVideo{} + err := json.Unmarshal(task.Data, openAIVideo) + if err != nil { + return nil, errors.Wrap(err, "unmarshal to OpenAIVideo failed") + } + return openAIVideo, nil +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 3fc1507b2..7939b48dd 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -550,3 +550,22 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther } return jsonDataAfter, nil } + +type OpenAIVideo struct { + ID string `json:"id"` + TaskID string `json:"task_id,omitempty"` //兼容旧接口 + Object string `json:"object"` + Model string `json:"model"` + Status string `json:"status"` + Progress int `json:"progress"` + CreatedAt int64 `json:"created_at"` + CompletedAt int64 `json:"completed_at,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + Seconds string `json:"seconds,omitempty"` + Size string `json:"size,omitempty"` + RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"` + Error *struct { + Message string `json:"message"` + Code string `json:"code"` + } `json:"error,omitempty"` +} diff --git a/relay/relay_task.go b/relay/relay_task.go index d447a40aa..0c4e9604c 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -11,6 +11,7 @@ import ( "one-api/constant" "one-api/dto" "one-api/model" + "one-api/relay/channel" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" @@ -367,7 +368,21 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d } if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") { - respBody = originTask.Data + adaptor := GetTaskAdaptor(originTask.Platform) + if adaptor == nil { + taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) + return + } + if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok { + openAIVideo, err := converter.ConvertToOpenAIVideo(originTask) + if err != nil { + taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError) + return + } + respBody, _ = json.Marshal(openAIVideo) + return + } + taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented) return } respBody, err = json.Marshal(dto.TaskResponse[any]{