feat(task): add adaptor billing interface and async settlement framework

Add three billing lifecycle methods to the TaskAdaptor interface:
- EstimateBilling: compute OtherRatios from user request before pricing
- AdjustBillingOnSubmit: adjust ratios from upstream submit response
- AdjustBillingOnComplete: determine final quota at task terminal state

Introduce BaseBilling as embeddable no-op default for adaptors without
custom billing. Move Sora/Ali OtherRatios logic from shared validation
into per-adaptor EstimateBilling implementations.

Add TaskBillingContext to persist pricing params (model_price, group_ratio,
other_ratios) in task private data for async polling settlement.

Extract RecalculateTaskQuota as a general-purpose delta settlement
function and unify polling billing via settleTaskBillingOnComplete
(adaptor-first, then token-based fallback).
This commit is contained in:
CaIon
2026-02-10 21:15:09 +08:00
parent 9e3954428d
commit d6e11fd2e1
19 changed files with 321 additions and 116 deletions

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"io"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
@@ -11,6 +12,7 @@ import (
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel"
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
@@ -56,6 +58,7 @@ type responseTask struct {
// ============================
type TaskAdaptor struct {
taskcommon.BaseBilling
ChannelType int
apiKey string
baseURL string
@@ -68,15 +71,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
}
func validateRemixRequest(c *gin.Context) *dto.TaskError {
var req struct {
Prompt string `json:"prompt"`
}
var req relaycommon.TaskSubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
}
if strings.TrimSpace(req.Prompt) == "" {
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
}
// 存储原始请求到 context与 ValidateMultipartDirect 路径保持一致
c.Set("task_request", req)
return nil
}
@@ -87,6 +90,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
return relaycommon.ValidateMultipartDirect(c, info)
}
// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
if info.Action == constant.TaskActionRemix {
return nil
}
req, err := relaycommon.GetTaskRequest(c)
if err != nil {
return nil
}
seconds, _ := strconv.Atoi(req.Seconds)
if seconds == 0 {
seconds = req.Duration
}
if seconds <= 0 {
seconds = 4
}
size := req.Size
if size == "" {
size = "720x1280"
}
ratios := map[string]float64{
"seconds": float64(seconds),
"size": 1,
}
if size == "1792x1024" || size == "1024x1792" {
ratios["size"] = 1.666667
}
return ratios
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.Action == constant.TaskActionRemix {
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil