Files
new-api/relay/channel/task/kling/adaptor.go
CaIon ec5c6b28ea feat(task): add model redirection, per-call billing, and multipart retry fix for async tasks
1. Async task model redirection (aligned with sync tasks):
   - Integrate ModelMappedHelper in RelayTaskSubmit after model name
     determination, populating OriginModelName / UpstreamModelName on RelayInfo.
   - All task adaptors now send UpstreamModelName to upstream providers:
     - Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
     - Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
     - Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
       and unconditionally uses info.UpstreamModelName.
     - Sora: BuildRequestBody parses JSON and multipart bodies to replace
       the "model" field with UpstreamModelName.
   - Frontend log visibility: LogTaskConsumption and taskBillingOther now
     emit is_model_mapped / upstream_model_name in the "other" JSON field.
   - Billing safety: RecalculateTaskQuotaByTokens reads model name from
     BillingContext.OriginModelName (via taskModelName) instead of
     task.Data["model"], preventing billing leaks from upstream model names.

2. Per-call billing (TaskPricePatches lifecycle):
   - Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
     bool field, populated from TaskPricePatches at submission time.
   - settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
     skipping both adaptor adjustments and token-based recalculation.
   - Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
     consistently in controller/relay.go for billing context and logging.

3. Multipart retry boundary mismatch fix:
   - Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
     new boundary and overwrites c.Request.Header["Content-Type"], subsequent
     calls to ParseMultipartFormReusable on retry would parse the cached
     original body with the wrong boundary, causing "NextPart: EOF".
   - Fix: ParseMultipartFormReusable now caches the original Content-Type in
     gin context key "_original_multipart_ct" on first call and reuses it for
     all subsequent parses, making multipart parsing retry-safe globally.
   - Sora adaptor reverted to the standard pattern (direct header set/get),
     which is now safe thanks to the root fix.

4. Tests:
   - task_billing_test.go: update makeTask to use OriginModelName; add
     PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
     add non-per-call adaptor adjustment test with refund verification.
2026-02-22 16:33:00 +08:00

388 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package kling
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/samber/lo"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
)
// ============================
// Request / Response structures
// ============================
// TrajectoryPoint is a single integer coordinate, serialized as
// {"x": ..., "y": ...}. It is the element type of DynamicMask.Trajectories.
type TrajectoryPoint struct {
X int `json:"x"`
Y int `json:"y"`
}
// DynamicMask pairs a mask value with a list of trajectory points.
// Both fields are dropped from the JSON output when empty (omitempty).
// NOTE(review): the encoding of Mask (URL, base64, ...) is not visible here —
// confirm against the upstream API.
type DynamicMask struct {
Mask string `json:"mask,omitempty"`
Trajectories []TrajectoryPoint `json:"trajectories,omitempty"`
}
// CameraConfig holds the six numeric camera parameters carried by
// CameraControl.Config. Every field uses omitempty, so a zero value is not
// serialized at all.
// NOTE(review): valid ranges/units for these values are not visible here.
type CameraConfig struct {
Horizontal float64 `json:"horizontal,omitempty"`
Vertical float64 `json:"vertical,omitempty"`
Pan float64 `json:"pan,omitempty"`
Tilt float64 `json:"tilt,omitempty"`
Roll float64 `json:"roll,omitempty"`
Zoom float64 `json:"zoom,omitempty"`
}
// CameraControl is the optional "camera_control" section of a request
// payload: a type string plus an optional CameraConfig. A nil Config is
// omitted from the JSON output.
type CameraControl struct {
Type string `json:"type,omitempty"`
Config *CameraConfig `json:"config,omitempty"`
}
// requestPayload is the JSON body sent upstream when submitting a task.
// It is populated by convertToRequestPayload (which also hands it to
// taskcommon.UnmarshalMetadata together with the request metadata) and
// marshalled in BuildRequestBody. All fields use omitempty.
type requestPayload struct {
Prompt string `json:"prompt,omitempty"`
Image string `json:"image,omitempty"`
ImageTail string `json:"image_tail,omitempty"`
NegativePrompt string `json:"negative_prompt,omitempty"`
Mode string `json:"mode,omitempty"`
// Duration is sent as a string, not a number.
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
// ModelName and Model are set to the same value by convertToRequestPayload.
ModelName string `json:"model_name,omitempty"`
Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
CfgScale float64 `json:"cfg_scale,omitempty"`
StaticMask string `json:"static_mask,omitempty"`
DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"`
CameraControl *CameraControl `json:"camera_control,omitempty"`
CallbackUrl string `json:"callback_url,omitempty"`
ExternalTaskId string `json:"external_task_id,omitempty"`
}
// responsePayload is the JSON envelope decoded from upstream responses. The
// same shape is used for the submit response (DoResponse), for polling results
// (ParseTaskResult) and for stored task data (ConvertToOpenAIVideo).
type responsePayload struct {
// Code is the upstream result code; DoResponse treats any non-zero value as
// a failure and reports Message as the error text.
Code int `json:"code"`
Message string `json:"message"`
// TaskId is a top-level "task_id", distinct from Data.TaskId.
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
Data struct {
TaskId string `json:"task_id"`
// TaskStatus is one of "submitted", "processing", "succeed" or "failed";
// ParseTaskResult rejects any other value.
TaskStatus string `json:"task_status"`
TaskStatusMsg string `json:"task_status_msg"`
TaskResult struct {
// Videos lists the generated videos; only the first entry is read in
// this file.
Videos []struct {
Id string `json:"id"`
Url string `json:"url"`
Duration string `json:"duration"`
} `json:"videos"`
} `json:"task_result"`
// CreatedAt / UpdatedAt are copied verbatim into the OpenAI video's
// CreatedAt / CompletedAt by ConvertToOpenAIVideo.
// NOTE(review): the unit (seconds vs. milliseconds) is not visible here.
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
} `json:"data"`
}
// ============================
// Adaptor implementation
// ============================
// TaskAdaptor is the Kling task adaptor. It embeds taskcommon.BaseBilling and
// carries the per-request channel settings that Init copies from the relay
// info.
type TaskAdaptor struct {
taskcommon.BaseBilling
// ChannelType mirrors RelayInfo.ChannelType.
ChannelType int
// apiKey is either "access_key|secret_key" or an "sk-" prefixed relay key
// (see createJWTTokenWithKey).
apiKey string
// baseURL mirrors RelayInfo.ChannelBaseUrl and prefixes every upstream URL
// built by BuildRequestURL.
baseURL string
}
// Init copies the channel type, base URL and API key from the relay info onto
// the adaptor so the other methods can use them.
//
// The API key is either "access_key|secret_key" or an "sk-" prefixed relay
// key; createJWTTokenWithKey handles both forms.
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
	a.apiKey = info.ApiKey
	a.baseURL = info.ChannelBaseUrl
	a.ChannelType = info.ChannelType
}
// ValidateRequestAndSetAction delegates to the shared basic task-request
// validation, passing constant.TaskActionGenerate as the action, and returns
// whatever task error that helper reports (nil on success).
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
	taskErr = relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
	return taskErr
}
// BuildRequestURL returns the upstream submit URL.
//
// The endpoint is image2video when info.Action equals
// constant.TaskActionGenerate and text2video otherwise. When the channel key
// is a relay key (see isNewAPIRelay) the path is additionally prefixed with
// "/kling". The returned error is always nil.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
	isImageToVideo := info.Action == constant.TaskActionGenerate
	endpoint := lo.Ternary(isImageToVideo, "/v1/videos/image2video", "/v1/videos/text2video")

	prefix := a.baseURL
	if isNewAPIRelay(info.ApiKey) {
		prefix += "/kling"
	}
	return prefix + endpoint, nil
}
// BuildRequestHeader sets the JSON content headers, the bearer token and the
// SDK user agent on the outgoing request.
//
// It fails, without touching any header, when the token cannot be derived
// from the adaptor's API key.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	bearer, err := a.createJWTToken()
	if err != nil {
		return fmt.Errorf("failed to create JWT token: %w", err)
	}

	headers := [][2]string{
		{"Content-Type", "application/json"},
		{"Accept", "application/json"},
		{"Authorization", "Bearer " + bearer},
		{"User-Agent", "kling-sdk/1.0"},
	}
	for _, kv := range headers {
		req.Header.Set(kv[0], kv[1])
	}
	return nil
}
// BuildRequestBody converts the task request stored in the gin context under
// "task_request" into the Kling JSON payload and returns it as a reader.
//
// When the resulting payload carries neither an image nor a tail image, the
// context key "action" is set to constant.TaskActionTextGenerate so that
// DoRequest can switch the relay action before the URL is built.
//
// It returns an error when the request is missing from the context, is not a
// relaycommon.TaskSubmitReq, cannot be converted, or cannot be marshalled.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
	v, exists := c.Get("task_request")
	if !exists {
		return nil, fmt.Errorf("request not found in context")
	}
	// Checked assertion: an unexpected value in the context must surface as an
	// error instead of panicking the request handler.
	req, ok := v.(relaycommon.TaskSubmitReq)
	if !ok {
		return nil, fmt.Errorf("unexpected task request type %T in context", v)
	}

	body, err := a.convertToRequestPayload(&req, info)
	if err != nil {
		return nil, err
	}
	if body.Image == "" && body.ImageTail == "" {
		c.Set("action", constant.TaskActionTextGenerate)
	}

	data, err := common.Marshal(body)
	if err != nil {
		return nil, err
	}
	return bytes.NewReader(data), nil
}
// DoRequest applies a pending action override and then hands the request to
// the shared task request helper.
//
// A non-empty "action" value in the gin context (set by BuildRequestBody)
// replaces info.Action before the helper runs.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	override := c.GetString("action")
	if override != "" {
		info.Action = override
	}
	return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse consumes the upstream submit response.
//
// On success it writes an OpenAI-style video object (keyed by the public task
// ID and the origin model name) to the client and returns the upstream task
// ID together with the raw response body. A read or decode failure, or a
// non-zero upstream code, is reported through taskErr.
//
// The response body is always closed before returning.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
	// Close the body on every path so the underlying connection is released.
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
		return
	}

	var kResp responsePayload
	if err = common.Unmarshal(responseBody, &kResp); err != nil {
		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
		return
	}
	if kResp.Code != 0 {
		taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("%s", kResp.Message), "task_failed", http.StatusBadRequest)
		return
	}

	// Prefer data.task_id; fall back to the top-level task_id when the data
	// section carries none, rather than recording an empty upstream ID.
	// NOTE(review): presumably relay upstreams answer with only the top-level
	// field — confirm against a real relay response.
	upstreamTaskID := kResp.Data.TaskId
	if upstreamTaskID == "" {
		upstreamTaskID = kResp.TaskId
	}

	ov := dto.NewOpenAIVideo()
	ov.ID = info.PublicTaskID
	ov.TaskID = info.PublicTaskID
	ov.CreatedAt = time.Now().Unix()
	ov.Model = info.OriginModelName
	c.JSON(http.StatusOK, ov)

	return upstreamTaskID, responseBody, nil
}
// FetchTask issues the GET request that polls the status of one task.
//
// body must contain string values under "task_id" and "action"; the action
// selects the image2video or text2video endpoint. Relay keys (see
// isNewAPIRelay) get the "/kling" path prefix. The request is sent through an
// HTTP client obtained for the given proxy.
//
// The response is returned as received from the client; closing its body is
// left to the caller.
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
	taskID, isString := body["task_id"].(string)
	if !isString {
		return nil, fmt.Errorf("invalid task_id")
	}
	action, isString := body["action"].(string)
	if !isString {
		return nil, fmt.Errorf("invalid action")
	}

	endpoint := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
	prefix := baseUrl
	if isNewAPIRelay(key) {
		prefix += "/kling"
	}

	req, err := http.NewRequest(http.MethodGet, prefix+endpoint+"/"+taskID, nil)
	if err != nil {
		return nil, err
	}

	bearer, err := a.createJWTTokenWithKey(key)
	if err != nil {
		// Best effort: when no token can be derived, send the key as-is.
		bearer = key
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+bearer)
	req.Header.Set("User-Agent", "kling-sdk/1.0")

	client, err := service.GetHttpClientWithProxy(proxy)
	if err != nil {
		return nil, fmt.Errorf("new proxy http client failed: %w", err)
	}
	return client.Do(req)
}
// GetModelList returns the model names this adaptor advertises. A fresh slice
// is built on every call.
func (a *TaskAdaptor) GetModelList() []string {
	models := make([]string, 0, 3)
	models = append(models, "kling-v1", "kling-v1-6", "kling-v2-master")
	return models
}
// GetChannelName returns the fixed channel identifier of this adaptor.
func (a *TaskAdaptor) GetChannelName() string {
	const channelName = "kling"
	return channelName
}
// ============================
// helpers
// ============================
// convertToRequestPayload builds the Kling request payload from a task submit
// request.
//
// Defaults: mode "std", duration 5, cfg_scale 0.5, and model "kling-v1" when
// the relay info carries no upstream model name. The request metadata is then
// applied on top of the payload via taskcommon.UnmarshalMetadata.
//
// The upstream model name from the relay info is authoritative: it is written
// again after the metadata merge so a "model_name" / "model" key in
// user-supplied metadata cannot replace the redirected model. Metadata may
// still override the "kling-v1" fallback, which applies only when no upstream
// model name is set.
//
// It returns an error when the metadata cannot be applied.
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq, info *relaycommon.RelayInfo) (*requestPayload, error) {
	upstreamModel := info.UpstreamModelName

	r := requestPayload{
		Prompt:         req.Prompt,
		Image:          req.Image,
		Mode:           taskcommon.DefaultString(req.Mode, "std"),
		Duration:       fmt.Sprintf("%d", taskcommon.DefaultInt(req.Duration, 5)),
		AspectRatio:    a.getAspectRatio(req.Size),
		ModelName:      upstreamModel,
		Model:          upstreamModel,
		CfgScale:       0.5,
		StaticMask:     "",
		DynamicMasks:   []DynamicMask{},
		CameraControl:  nil,
		CallbackUrl:    "",
		ExternalTaskId: "",
	}
	if r.ModelName == "" {
		r.ModelName = "kling-v1"
		r.Model = "kling-v1"
	}

	if err := taskcommon.UnmarshalMetadata(req.Metadata, &r); err != nil {
		return nil, errors.Wrap(err, "unmarshal metadata failed")
	}

	// Re-pin the model: metadata is user-controlled and must not bypass model
	// redirection (and the billing tied to it).
	if upstreamModel != "" {
		r.ModelName = upstreamModel
		r.Model = upstreamModel
	}
	return &r, nil
}
// getAspectRatio maps a "WIDTHxHEIGHT" size string to an aspect ratio.
//
// Landscape HD sizes map to "16:9", portrait HD sizes to "9:16"; the square
// sizes and every unrecognized value (including the empty string) map to
// "1:1".
func (a *TaskAdaptor) getAspectRatio(size string) string {
	const square = "1:1"

	switch size {
	case "1280x720", "1920x1080":
		return "16:9"
	case "720x1280", "1080x1920":
		return "9:16"
	case "1024x1024", "512x512":
		return square
	}
	return square
}
// ============================
// JWT helpers
// ============================
// createJWTToken derives the bearer token from the API key stored on the
// adaptor; see createJWTTokenWithKey for the accepted key formats.
func (a *TaskAdaptor) createJWTToken() (string, error) {
	token, err := a.createJWTTokenWithKey(a.apiKey)
	return token, err
}
// createJWTTokenWithKey turns an API key into the bearer token sent upstream.
//
// Relay keys (see isNewAPIRelay) are returned unchanged. Any other key must
// have the form "accessKey|secretKey": the access key becomes the "iss" claim
// of an HS256 JWT signed with the secret key, valid from 5 seconds in the past
// until 30 minutes from now.
//
// It returns an error when the key does not consist of exactly two
// "|"-separated parts or when either part is blank.
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
	if isNewAPIRelay(apiKey) {
		return apiKey, nil // new api relay
	}

	keyParts := strings.Split(apiKey, "|")
	if len(keyParts) != 2 {
		return "", errors.New("invalid api_key, required format is accessKey|secretKey")
	}
	accessKey := strings.TrimSpace(keyParts[0])
	secretKey := strings.TrimSpace(keyParts[1])
	// Never sign with an empty secret or issue a token without an issuer.
	if accessKey == "" || secretKey == "" {
		return "", errors.New("invalid api_key, required format is accessKey|secretKey")
	}

	now := time.Now().Unix()
	claims := jwt.MapClaims{
		"iss": accessKey,
		"exp": now + 1800, // 30 minutes
		"nbf": now - 5, // small allowance for clock skew
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	token.Header["typ"] = "JWT"
	return token.SignedString([]byte(secretKey))
}
// ParseTaskResult decodes a polling response body into a TaskInfo.
//
// The upstream code, task ID and status message are copied over, the upstream
// status string is translated into the internal task status, and the URL of
// the first result video (if any) is recorded.
//
// It returns an error when the body cannot be decoded or the status string is
// not one of the known values.
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
	var payload responsePayload
	if err := common.Unmarshal(respBody, &payload); err != nil {
		return nil, errors.Wrap(err, "failed to unmarshal response body")
	}

	data := &payload.Data
	result := &relaycommon.TaskInfo{
		Code:   payload.Code,
		TaskID: data.TaskId,
		Reason: data.TaskStatusMsg,
	}

	// Upstream status values: submitted, processing, succeed, failed.
	switch data.TaskStatus {
	case "submitted":
		result.Status = model.TaskStatusSubmitted
	case "processing":
		result.Status = model.TaskStatusInProgress
	case "succeed":
		result.Status = model.TaskStatusSuccess
	case "failed":
		result.Status = model.TaskStatusFailure
	default:
		return nil, fmt.Errorf("unknown task status: %s", data.TaskStatus)
	}

	if len(data.TaskResult.Videos) > 0 {
		result.Url = data.TaskResult.Videos[0].Url
	}
	return result, nil
}
// isNewAPIRelay reports whether apiKey is a relay key, i.e. whether it starts
// with "sk-". The check is case-sensitive and does not trim whitespace.
func isNewAPIRelay(apiKey string) bool {
	const relayKeyPrefix = "sk-"
	return strings.HasPrefix(apiKey, relayKeyPrefix)
}
// ConvertToOpenAIVideo renders a stored task as a marshalled OpenAI-style
// video object.
//
// ID, status and progress come from the task record; the created/completed
// timestamps, the first video's URL (stored as "url" metadata) and duration,
// and any upstream error come from the Kling response saved in the task data.
//
// It returns an error when the task data cannot be decoded or the result
// cannot be marshalled.
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
	var payload responsePayload
	err := common.Unmarshal(originTask.Data, &payload)
	if err != nil {
		return nil, errors.Wrap(err, "unmarshal kling task data failed")
	}

	out := dto.NewOpenAIVideo()
	out.ID = originTask.TaskID
	out.Status = originTask.Status.ToVideoStatus()
	out.SetProgressStr(originTask.Progress)
	out.CreatedAt = payload.Data.CreatedAt
	out.CompletedAt = payload.Data.UpdatedAt

	if videos := payload.Data.TaskResult.Videos; len(videos) != 0 {
		first := videos[0]
		if first.Url != "" {
			out.SetMetadata("url", first.Url)
		}
		if first.Duration != "" {
			out.Seconds = first.Duration
		}
	}

	// An error is attached only when the upstream reported both a non-zero
	// code and a message.
	if payload.Code == 0 || payload.Message == "" {
		return common.Marshal(out)
	}
	out.Error = &dto.OpenAIVideoError{
		Message: payload.Message,
		Code:    fmt.Sprintf("%d", payload.Code),
	}
	return common.Marshal(out)
}