Merge pull request #2004 from feitianbubu/pr/openai-sdk-kling

支持可灵使用openai sdk生成视频
This commit is contained in:
Calcium-Ion
2025-10-11 11:05:11 +08:00
committed by GitHub
7 changed files with 157 additions and 30 deletions

View File

@@ -174,14 +174,22 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
relayMode := relayconstant.RelayModeUnknown
if c.Request.Method == http.MethodPost {
relayMode = relayconstant.RelayModeVideoSubmit
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return nil, false, errors.New("无效的video请求, " + err.Error())
}
defer form.RemoveAll()
if form != nil {
if values, ok := form.Value["model"]; ok && len(values) > 0 {
modelRequest.Model = values[0]
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "multipart/form-data") {
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return nil, false, errors.New("无效的video请求, " + err.Error())
}
defer form.RemoveAll()
if form != nil {
if values, ok := form.Value["model"]; ok && len(values) > 0 {
modelRequest.Model = values[0]
}
}
} else if strings.HasPrefix(contentType, "application/json") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return nil, false, errors.New("无效的video请求, " + err.Error())
}
}
} else if c.Request.Method == http.MethodGet {

View File

@@ -4,6 +4,7 @@ import (
"io"
"net/http"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/types"
@@ -49,3 +50,7 @@ type TaskAdaptor interface {
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}
type OpenAIVideoConverter interface {
ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error)
}

View File

@@ -7,9 +7,11 @@ import (
"io"
"net/http"
"one-api/model"
"strconv"
"strings"
"time"
"github.com/bytedance/gopkg/util/logger"
"github.com/samber/lo"
"github.com/gin-gonic/gin"
@@ -303,14 +305,6 @@ func (a *TaskAdaptor) createJWTToken() (string, error) {
return a.createJWTTokenWithKey(a.apiKey)
}
//func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
// parts := strings.Split(apiKey, "|")
// if len(parts) != 2 {
// return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
// }
// return a.createJWTTokenWithKey(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
//}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
if isNewAPIRelay(apiKey) {
return apiKey, nil // new api relay
@@ -369,3 +363,50 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
func isNewAPIRelay(apiKey string) bool {
return strings.HasPrefix(apiKey, "sk-")
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) (*relaycommon.OpenAIVideo, error) {
var klingResp responsePayload
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
return nil, errors.Wrap(err, "unmarshal kling task data failed")
}
convertProgress := func(progress string) int {
progress = strings.TrimSuffix(progress, "%")
p, err := strconv.Atoi(progress)
if err != nil {
logger.Warnf("convert progress failed, progress: %s, err: %v", progress, err)
}
return p
}
openAIVideo := &relaycommon.OpenAIVideo{
ID: klingResp.Data.TaskId,
Object: "video",
//Model: "kling-v1", //todo save model
Status: string(originTask.Status),
CreatedAt: klingResp.Data.CreatedAt,
CompletedAt: klingResp.Data.UpdatedAt,
Metadata: make(map[string]any),
Progress: convertProgress(originTask.Progress),
}
// 处理视频 URL
if len(klingResp.Data.TaskResult.Videos) > 0 {
video := klingResp.Data.TaskResult.Videos[0]
if video.Url != "" {
openAIVideo.Metadata["url"] = video.Url
}
if video.Duration != "" {
openAIVideo.Seconds = video.Duration
}
}
if klingResp.Code != 0 && klingResp.Message != "" {
openAIVideo.Error = &relaycommon.OpenAIVideoError{
Message: klingResp.Message,
Code: fmt.Sprintf("%d", klingResp.Code),
}
}
return openAIVideo, nil
}

View File

@@ -184,3 +184,12 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
return &taskResult, nil
}
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) (*relaycommon.OpenAIVideo, error) {
openAIVideo := &relaycommon.OpenAIVideo{}
err := json.Unmarshal(task.Data, openAIVideo)
if err != nil {
return nil, errors.Wrap(err, "unmarshal to OpenAIVideo failed")
}
return openAIVideo, nil
}

View File

@@ -550,3 +550,24 @@ func RemoveDisabledFields(jsonData []byte, channelOtherSettings dto.ChannelOther
}
return jsonDataAfter, nil
}
type OpenAIVideo struct {
ID string `json:"id"`
TaskID string `json:"task_id,omitempty"` //兼容旧接口 待废弃
Object string `json:"object"`
Model string `json:"model"`
Status string `json:"status"`
Progress int `json:"progress"`
CreatedAt int64 `json:"created_at"`
CompletedAt int64 `json:"completed_at,omitempty"`
ExpiresAt int64 `json:"expires_at,omitempty"`
Seconds string `json:"seconds,omitempty"`
Size string `json:"size,omitempty"`
RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
Error *OpenAIVideoError `json:"error,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}
type OpenAIVideoError struct {
Message string `json:"message"`
Code string `json:"code"`
}

View File

@@ -106,25 +106,53 @@ func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string
}
func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
}
defer form.RemoveAll()
contentType := c.GetHeader("Content-Type")
var prompt string
var hasInputReference bool
prompts, ok := form.Value["prompt"]
if !ok || len(prompts) == 0 {
return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
if strings.HasPrefix(contentType, "multipart/form-data") {
form, err := common.ParseMultipartFormReusable(c)
if err != nil {
return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
}
defer form.RemoveAll()
prompts, ok := form.Value["prompt"]
if !ok || len(prompts) == 0 {
return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
}
prompt = prompts[0]
if _, ok := form.Value["model"]; !ok {
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
}
if _, ok := form.File["input_reference"]; ok {
hasInputReference = true
}
} else {
var req TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return createTaskError(err, "invalid_json", http.StatusBadRequest, true)
}
prompt = req.Prompt
if strings.TrimSpace(req.Model) == "" {
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
}
if req.HasImage() {
hasInputReference = true
}
}
if taskErr := validatePrompt(prompts[0]); taskErr != nil {
if taskErr := validatePrompt(prompt); taskErr != nil {
return taskErr
}
if _, ok := form.Value["model"]; !ok {
return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
}
action := constant.TaskActionTextGenerate
if _, ok := form.File["input_reference"]; ok {
if hasInputReference {
action = constant.TaskActionGenerate
}
info.Action = action

View File

@@ -11,6 +11,7 @@ import (
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
@@ -367,7 +368,21 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
}
if strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") {
respBody = originTask.Data
adaptor := GetTaskAdaptor(originTask.Platform)
if adaptor == nil {
taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest)
return
}
if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok {
openAIVideo, err := converter.ConvertToOpenAIVideo(originTask)
if err != nil {
taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError)
return
}
respBody, _ = json.Marshal(openAIVideo)
return
}
taskResp = service.TaskErrorWrapperLocal(errors.New(fmt.Sprintf("not_implemented:%s", originTask.Platform)), "not_implemented", http.StatusNotImplemented)
return
}
respBody, err = json.Marshal(dto.TaskResponse[any]{