mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:25:00 +00:00
Add three billing lifecycle methods to the TaskAdaptor interface:
- EstimateBilling: compute OtherRatios from the user request before pricing.
- AdjustBillingOnSubmit: adjust ratios based on the upstream submit response.
- AdjustBillingOnComplete: determine the final quota when the task reaches a terminal state.
Introduce BaseBilling as an embeddable no-op default for adaptors without custom billing. Move the Sora/Ali OtherRatios logic from shared validation into per-adaptor EstimateBilling implementations. Add TaskBillingContext to persist pricing parameters (model_price, group_ratio, other_ratios) in the task's private data for settlement during asynchronous polling. Extract RecalculateTaskQuota as a general-purpose delta-settlement function, and unify polling billing via settleTaskBillingOnComplete (adaptor first, then token-based fallback).
255 lines
7.4 KiB
Go
255 lines
7.4 KiB
Go
package sora
|
||
|
||
import (
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
"github.com/QuantumNous/new-api/constant"
|
||
"github.com/QuantumNous/new-api/dto"
|
||
"github.com/QuantumNous/new-api/model"
|
||
"github.com/QuantumNous/new-api/relay/channel"
|
||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||
"github.com/QuantumNous/new-api/service"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/pkg/errors"
|
||
)
|
||
|
||
// ============================
|
||
// Request / Response structures
|
||
// ============================
|
||
|
||
// ContentItem is one element of a multi-part content array. Type selects
// which payload field is meaningful: Text for "text", ImageURL for "image_url".
type ContentItem struct {
	Type     string    `json:"type"`                // "text" or "image_url"
	Text     string    `json:"text,omitempty"`      // populated when Type is "text"
	ImageURL *ImageURL `json:"image_url,omitempty"` // populated when Type is "image_url"
}

// ImageURL wraps the URL of an image referenced by a ContentItem.
type ImageURL struct {
	URL string `json:"url"`
}
|
||
|
||
// responseTask mirrors the upstream video task object returned both by the
// submit call and by the status query. Field order is significant: it is the
// order in which the struct is re-serialized to the client.
type responseTask struct {
	ID                 string `json:"id"`
	TaskID             string `json:"task_id,omitempty"` // kept for compatibility with the legacy API
	Object             string `json:"object"`
	Model              string `json:"model"`
	Status             string `json:"status"`
	Progress           int    `json:"progress"`
	CreatedAt          int64  `json:"created_at"`
	CompletedAt        int64  `json:"completed_at,omitempty"`
	ExpiresAt          int64  `json:"expires_at,omitempty"`
	Seconds            string `json:"seconds,omitempty"`
	Size               string `json:"size,omitempty"`
	RemixedFromVideoID string `json:"remixed_from_video_id,omitempty"`
	// Error is set by the upstream when the task did not succeed.
	Error *struct {
		Message string `json:"message"`
		Code    string `json:"code"`
	} `json:"error,omitempty"`
}
|
||
|
||
// ============================
|
||
// Adaptor implementation
|
||
// ============================
|
||
|
||
type TaskAdaptor struct {
|
||
taskcommon.BaseBilling
|
||
ChannelType int
|
||
apiKey string
|
||
baseURL string
|
||
}
|
||
|
||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||
a.ChannelType = info.ChannelType
|
||
a.baseURL = info.ChannelBaseUrl
|
||
a.apiKey = info.ApiKey
|
||
}
|
||
|
||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||
var req relaycommon.TaskSubmitReq
|
||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||
}
|
||
if strings.TrimSpace(req.Prompt) == "" {
|
||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||
}
|
||
// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
|
||
c.Set("task_request", req)
|
||
return nil
|
||
}
|
||
|
||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||
if info.Action == constant.TaskActionRemix {
|
||
return validateRemixRequest(c)
|
||
}
|
||
return relaycommon.ValidateMultipartDirect(c, info)
|
||
}
|
||
|
||
// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
|
||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||
// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
|
||
if info.Action == constant.TaskActionRemix {
|
||
return nil
|
||
}
|
||
|
||
req, err := relaycommon.GetTaskRequest(c)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
|
||
seconds, _ := strconv.Atoi(req.Seconds)
|
||
if seconds == 0 {
|
||
seconds = req.Duration
|
||
}
|
||
if seconds <= 0 {
|
||
seconds = 4
|
||
}
|
||
|
||
size := req.Size
|
||
if size == "" {
|
||
size = "720x1280"
|
||
}
|
||
|
||
ratios := map[string]float64{
|
||
"seconds": float64(seconds),
|
||
"size": 1,
|
||
}
|
||
if size == "1792x1024" || size == "1024x1792" {
|
||
ratios["size"] = 1.666667
|
||
}
|
||
return ratios
|
||
}
|
||
|
||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||
if info.Action == constant.TaskActionRemix {
|
||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||
}
|
||
return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
|
||
}
|
||
|
||
// BuildRequestHeader sets required headers.
|
||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||
req.Header.Set("Authorization", "Bearer "+a.apiKey)
|
||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||
return nil
|
||
}
|
||
|
||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||
storage, err := common.GetBodyStorage(c)
|
||
if err != nil {
|
||
return nil, errors.Wrap(err, "get_request_body_failed")
|
||
}
|
||
return common.ReaderOnly(storage), nil
|
||
}
|
||
|
||
// DoRequest delegates to common helper.
|
||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||
}
|
||
|
||
// DoResponse handles upstream response, returns taskID etc.
|
||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
_ = resp.Body.Close()
|
||
|
||
// Parse Sora response
|
||
var dResp responseTask
|
||
if err := common.Unmarshal(responseBody, &dResp); err != nil {
|
||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
upstreamID := dResp.ID
|
||
if upstreamID == "" {
|
||
upstreamID = dResp.TaskID
|
||
}
|
||
if upstreamID == "" {
|
||
taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
||
return
|
||
}
|
||
|
||
// 使用公开 task_xxxx ID 返回给客户端
|
||
dResp.ID = info.PublicTaskID
|
||
dResp.TaskID = info.PublicTaskID
|
||
c.JSON(http.StatusOK, dResp)
|
||
return upstreamID, responseBody, nil
|
||
}
|
||
|
||
// FetchTask fetch task status
|
||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||
taskID, ok := body["task_id"].(string)
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid task_id")
|
||
}
|
||
|
||
uri := fmt.Sprintf("%s/v1/videos/%s", baseUrl, taskID)
|
||
|
||
req, err := http.NewRequest(http.MethodGet, uri, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+key)
|
||
|
||
client, err := service.GetHttpClientWithProxy(proxy)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("new proxy http client failed: %w", err)
|
||
}
|
||
return client.Do(req)
|
||
}
|
||
|
||
func (a *TaskAdaptor) GetModelList() []string {
|
||
return ModelList
|
||
}
|
||
|
||
func (a *TaskAdaptor) GetChannelName() string {
|
||
return ChannelName
|
||
}
|
||
|
||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||
resTask := responseTask{}
|
||
if err := common.Unmarshal(respBody, &resTask); err != nil {
|
||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||
}
|
||
|
||
taskResult := relaycommon.TaskInfo{
|
||
Code: 0,
|
||
}
|
||
|
||
switch resTask.Status {
|
||
case "queued", "pending":
|
||
taskResult.Status = model.TaskStatusQueued
|
||
case "processing", "in_progress":
|
||
taskResult.Status = model.TaskStatusInProgress
|
||
case "completed":
|
||
taskResult.Status = model.TaskStatusSuccess
|
||
// Url intentionally left empty — the caller constructs the proxy URL using the public task ID
|
||
case "failed", "cancelled":
|
||
taskResult.Status = model.TaskStatusFailure
|
||
if resTask.Error != nil {
|
||
taskResult.Reason = resTask.Error.Message
|
||
} else {
|
||
taskResult.Reason = "task failed"
|
||
}
|
||
default:
|
||
}
|
||
if resTask.Progress > 0 && resTask.Progress < 100 {
|
||
taskResult.Progress = fmt.Sprintf("%d%%", resTask.Progress)
|
||
}
|
||
|
||
return &taskResult, nil
|
||
}
|
||
|
||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||
return task.Data, nil
|
||
}
|