mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
410 lines
12 KiB
Go
410 lines
12 KiB
Go
package kling
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"io"
|
||
"math"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/model"
|
||
|
||
"github.com/samber/lo"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
"github.com/pkg/errors"
|
||
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/relay/channel"
|
||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
"github.com/QuantumNous/new-api/service"
|
||
)
|
||
|
||
// ============================
|
||
// Request / Response structures
|
||
// ============================
|
||
|
||
type TrajectoryPoint struct {
|
||
X int `json:"x"`
|
||
Y int `json:"y"`
|
||
}
|
||
|
||
type DynamicMask struct {
|
||
Mask string `json:"mask,omitempty"`
|
||
Trajectories []TrajectoryPoint `json:"trajectories,omitempty"`
|
||
}
|
||
|
||
type CameraConfig struct {
|
||
Horizontal float64 `json:"horizontal,omitempty"`
|
||
Vertical float64 `json:"vertical,omitempty"`
|
||
Pan float64 `json:"pan,omitempty"`
|
||
Tilt float64 `json:"tilt,omitempty"`
|
||
Roll float64 `json:"roll,omitempty"`
|
||
Zoom float64 `json:"zoom,omitempty"`
|
||
}
|
||
|
||
type CameraControl struct {
|
||
Type string `json:"type,omitempty"`
|
||
Config *CameraConfig `json:"config,omitempty"`
|
||
}
|
||
|
||
type requestPayload struct {
|
||
Prompt string `json:"prompt,omitempty"`
|
||
Image string `json:"image,omitempty"`
|
||
ImageTail string `json:"image_tail,omitempty"`
|
||
NegativePrompt string `json:"negative_prompt,omitempty"`
|
||
Mode string `json:"mode,omitempty"`
|
||
Duration string `json:"duration,omitempty"`
|
||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||
ModelName string `json:"model_name,omitempty"`
|
||
Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
|
||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||
StaticMask string `json:"static_mask,omitempty"`
|
||
DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"`
|
||
CameraControl *CameraControl `json:"camera_control,omitempty"`
|
||
CallbackUrl string `json:"callback_url,omitempty"`
|
||
ExternalTaskId string `json:"external_task_id,omitempty"`
|
||
}
|
||
|
||
type responsePayload struct {
|
||
Code int `json:"code"`
|
||
Message string `json:"message"`
|
||
TaskId string `json:"task_id"`
|
||
RequestId string `json:"request_id"`
|
||
Data struct {
|
||
TaskId string `json:"task_id"`
|
||
TaskStatus string `json:"task_status"`
|
||
TaskStatusMsg string `json:"task_status_msg"`
|
||
TaskInfo struct {
|
||
ExternalTaskId string `json:"external_task_id"`
|
||
} `json:"task_info"`
|
||
WatermarkInfo struct {
|
||
Enabled bool `json:"enabled"`
|
||
} `json:"watermark_info"`
|
||
TaskResult struct {
|
||
Videos []struct {
|
||
Id string `json:"id"`
|
||
Url string `json:"url"`
|
||
WatermarkUrl string `json:"watermark_url"`
|
||
Duration string `json:"duration"`
|
||
} `json:"videos"`
|
||
Images []struct {
|
||
Index int `json:"index"`
|
||
Url string `json:"url"`
|
||
WatermarkUrl string `json:"watermark_url"`
|
||
} `json:"images"`
|
||
} `json:"task_result"`
|
||
CreatedAt int64 `json:"created_at"`
|
||
UpdatedAt int64 `json:"updated_at"`
|
||
FinalUnitDeduction string `json:"final_unit_deduction"`
|
||
} `json:"data"`
|
||
}
|
||
|
||
// ============================
|
||
// Adaptor implementation
|
||
// ============================
|
||
|
||
type TaskAdaptor struct {
|
||
taskcommon.BaseBilling
|
||
ChannelType int
|
||
apiKey string
|
||
baseURL string
|
||
}
|
||
|
||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||
a.ChannelType = info.ChannelType
|
||
a.baseURL = info.ChannelBaseUrl
|
||
a.apiKey = info.ApiKey
|
||
|
||
// apiKey format: "access_key|secret_key"
|
||
}
|
||
|
||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||
// Use the standard validation method for TaskSubmitReq
|
||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||
}
|
||
|
||
// BuildRequestURL constructs the upstream URL.
|
||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||
|
||
if isNewAPIRelay(info.ApiKey) {
|
||
return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
|
||
}
|
||
|
||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||
}
|
||
|
||
// BuildRequestHeader sets required headers.
|
||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||
token, err := a.createJWTToken()
|
||
if err != nil {
|
||
return fmt.Errorf("failed to create JWT token: %w", err)
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+token)
|
||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||
return nil
|
||
}
|
||
|
||
// BuildRequestBody converts request into Kling specific format.
|
||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||
v, exists := c.Get("task_request")
|
||
if !exists {
|
||
return nil, fmt.Errorf("request not found in context")
|
||
}
|
||
req := v.(relaycommon.TaskSubmitReq)
|
||
|
||
body, err := a.convertToRequestPayload(&req, info)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if body.Image == "" && body.ImageTail == "" {
|
||
c.Set("action", constant.TaskActionTextGenerate)
|
||
}
|
||
data, err := common.Marshal(body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return bytes.NewReader(data), nil
|
||
}
|
||
|
||
// DoRequest delegates to common helper.
|
||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||
if action := c.GetString("action"); action != "" {
|
||
info.Action = action
|
||
}
|
||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||
}
|
||
|
||
// DoResponse handles upstream response, returns taskID etc.
|
||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
var kResp responsePayload
|
||
err = common.Unmarshal(responseBody, &kResp)
|
||
if err != nil {
|
||
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
if kResp.Code != 0 {
|
||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("%s", kResp.Message), "task_failed", http.StatusBadRequest)
|
||
return
|
||
}
|
||
ov := dto.NewOpenAIVideo()
|
||
ov.ID = info.PublicTaskID
|
||
ov.TaskID = info.PublicTaskID
|
||
ov.CreatedAt = time.Now().Unix()
|
||
ov.Model = info.OriginModelName
|
||
c.JSON(http.StatusOK, ov)
|
||
return kResp.Data.TaskId, responseBody, nil
|
||
}
|
||
|
||
// FetchTask fetch task status
|
||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||
taskID, ok := body["task_id"].(string)
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid task_id")
|
||
}
|
||
action, ok := body["action"].(string)
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid action")
|
||
}
|
||
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
|
||
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||
if isNewAPIRelay(key) {
|
||
url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
|
||
}
|
||
|
||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
token, err := a.createJWTTokenWithKey(key)
|
||
if err != nil {
|
||
token = key
|
||
}
|
||
|
||
req.Header.Set("Accept", "application/json")
|
||
req.Header.Set("Authorization", "Bearer "+token)
|
||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||
|
||
client, err := service.GetHttpClientWithProxy(proxy)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||
}
|
||
return client.Do(req)
|
||
}
|
||
|
||
func (a *TaskAdaptor) GetModelList() []string {
|
||
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
||
}
|
||
|
||
func (a *TaskAdaptor) GetChannelName() string {
|
||
return "kling"
|
||
}
|
||
|
||
// ============================
|
||
// helpers
|
||
// ============================
|
||
|
||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
|
||
r := requestPayload{
|
||
Prompt: req.Prompt,
|
||
Image: req.Image,
|
||
Mode: taskcommon.DefaultString(req.Mode, "std"),
|
||
Duration: fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
|
||
AspectRatio: a.getAspectRatio(req.Size),
|
||
ModelName: info.UpstreamModelName,
|
||
Model: info.UpstreamModelName,
|
||
CfgScale: 0.5,
|
||
StaticMask: "",
|
||
DynamicMasks: []DynamicMask{},
|
||
CameraControl: nil,
|
||
CallbackUrl: "",
|
||
ExternalTaskId: "",
|
||
}
|
||
if r.ModelName == "" {
|
||
r.ModelName = "kling-v1"
|
||
r.Model = "kling-v1"
|
||
}
|
||
if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
|
||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||
}
|
||
return &r, nil
|
||
}
|
||
|
||
func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||
switch size {
|
||
case "1024x1024", "512x512":
|
||
return "1:1"
|
||
case "1280x720", "1920x1080":
|
||
return "16:9"
|
||
case "720x1280", "1080x1920":
|
||
return "9:16"
|
||
default:
|
||
return "1:1"
|
||
}
|
||
}
|
||
|
||
// ============================
|
||
// JWT helpers
|
||
// ============================
|
||
|
||
func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||
return a.createJWTTokenWithKey(a.apiKey)
|
||
}
|
||
|
||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||
if isNewAPIRelay(apiKey) {
|
||
return apiKey, nil // new api relay
|
||
}
|
||
keyParts := strings.Split(apiKey, "|")
|
||
if len(keyParts) != 2 {
|
||
return "", errors.New("invalid api_key, required format is accessKey|secretKey")
|
||
}
|
||
accessKey := strings.TrimSpace(keyParts[0])
|
||
if len(keyParts) == 1 {
|
||
return accessKey, nil
|
||
}
|
||
secretKey := strings.TrimSpace(keyParts[1])
|
||
now := time.Now().Unix()
|
||
claims := jwt.MapClaims{
|
||
"iss": accessKey,
|
||
"exp": now + 1800, // 30 minutes
|
||
"nbf": now - 5,
|
||
}
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||
token.Header["typ"] = "JWT"
|
||
return token.SignedString([]byte(secretKey))
|
||
}
|
||
|
||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||
taskInfo := &relaycommon.TaskInfo{}
|
||
resPayload := responsePayload{}
|
||
err := common.Unmarshal(respBody, &resPayload)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||
}
|
||
taskInfo.Code = resPayload.Code
|
||
taskInfo.TaskID = resPayload.Data.TaskId
|
||
taskInfo.Reason = resPayload.Data.TaskStatusMsg
|
||
//任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
|
||
status := resPayload.Data.TaskStatus
|
||
switch status {
|
||
case "submitted":
|
||
taskInfo.Status = model.TaskStatusSubmitted
|
||
case "processing":
|
||
taskInfo.Status = model.TaskStatusInProgress
|
||
case "succeed":
|
||
taskInfo.Status = model.TaskStatusSuccess
|
||
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
|
||
video := videos[0]
|
||
taskInfo.Url = video.Url
|
||
}
|
||
if tokens, err := strconv.ParseFloat(resPayload.Data.FinalUnitDeduction, 64); err == nil {
|
||
rounded := int(math.Ceil(tokens))
|
||
if rounded > 0 {
|
||
taskInfo.CompletionTokens = rounded
|
||
taskInfo.TotalTokens = rounded
|
||
}
|
||
}
|
||
case "failed":
|
||
taskInfo.Status = model.TaskStatusFailure
|
||
default:
|
||
return nil, fmt.Errorf("unknown task status: %s", status)
|
||
}
|
||
return taskInfo, nil
|
||
}
|
||
|
||
func isNewAPIRelay(apiKey string) bool {
|
||
return strings.HasPrefix(apiKey, "sk-")
|
||
}
|
||
|
||
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
|
||
var klingResp responsePayload
|
||
if err := common.Unmarshal(originTask.Data, &klingResp); err != nil {
|
||
return nil, errors.Wrap(err, "unmarshal kling task data failed")
|
||
}
|
||
|
||
openAIVideo := dto.NewOpenAIVideo()
|
||
openAIVideo.ID = originTask.TaskID
|
||
openAIVideo.Status = originTask.Status.ToVideoStatus()
|
||
openAIVideo.SetProgressStr(originTask.Progress)
|
||
openAIVideo.CreatedAt = klingResp.Data.CreatedAt
|
||
openAIVideo.CompletedAt = klingResp.Data.UpdatedAt
|
||
|
||
if len(klingResp.Data.TaskResult.Videos) > 0 {
|
||
video := klingResp.Data.TaskResult.Videos[0]
|
||
if video.Url != "" {
|
||
openAIVideo.SetMetadata("url", video.Url)
|
||
}
|
||
if video.Duration != "" {
|
||
openAIVideo.Seconds = video.Duration
|
||
}
|
||
}
|
||
|
||
if klingResp.Code != 0 && klingResp.Message != "" {
|
||
openAIVideo.Error = &dto.OpenAIVideoError{
|
||
Message: klingResp.Message,
|
||
Code: fmt.Sprintf("%d", klingResp.Code),
|
||
}
|
||
}
|
||
return common.Marshal(openAIVideo)
|
||
}
|