mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-25 01:08:38 +00:00
feat(task): add adaptor billing interface and async settlement framework
Add three billing lifecycle methods to the TaskAdaptor interface: - EstimateBilling: compute OtherRatios from user request before pricing - AdjustBillingOnSubmit: adjust ratios from upstream submit response - AdjustBillingOnComplete: determine final quota at task terminal state Introduce BaseBilling as embeddable no-op default for adaptors without custom billing. Move Sora/Ali OtherRatios logic from shared validation into per-adaptor EstimateBilling implementations. Add TaskBillingContext to persist pricing params (model_price, group_ratio, other_ratios) in task private data for async polling settlement. Extract RecalculateTaskQuota as a general-purpose delta settlement function and unify polling billing via settleTaskBillingOnComplete (adaptor-first, then token-based fallback).
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/samber/lo"
|
||||
@@ -108,10 +109,10 @@ type AliMetadata struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
aliReq *AliVideoRequest
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// 阿里通义万相支持 JSON 格式,不使用 multipart
|
||||
var taskReq relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
|
||||
return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
|
||||
}
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
a.aliReq = aliReq
|
||||
logger.LogJson(c, "ali video request body", aliReq)
|
||||
// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
bodyBytes, err := common.Marshal(a.aliReq)
|
||||
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get_task_request_failed")
|
||||
}
|
||||
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert_to_ali_request_failed")
|
||||
}
|
||||
logger.LogJson(c, "ali video request body", aliReq)
|
||||
|
||||
bodyBytes, err := common.Marshal(aliReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal_ali_request_failed")
|
||||
}
|
||||
|
||||
return bytes.NewReader(bodyBytes), nil
|
||||
}
|
||||
|
||||
@@ -335,19 +336,33 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
||||
return nil, errors.New("can't change model with metadata")
|
||||
}
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
return aliReq, nil
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
|
||||
// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
otherRatios := map[string]float64{
|
||||
"seconds": float64(aliReq.Parameters.Duration),
|
||||
}
|
||||
|
||||
ratios, err := ProcessAliOtherRatios(aliReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return otherRatios
|
||||
}
|
||||
for s, f := range ratios {
|
||||
info.PriceData.OtherRatios[s] = f
|
||||
for k, v := range ratios {
|
||||
otherRatios[k] = v
|
||||
}
|
||||
|
||||
return aliReq, nil
|
||||
return otherRatios
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper
|
||||
|
||||
@@ -89,6 +89,7 @@ type responseTask struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
|
||||
@@ -85,6 +85,7 @@ type operationResponse struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
|
||||
@@ -17,12 +17,14 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
)
|
||||
|
||||
// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
|
||||
@@ -77,6 +77,7 @@ const (
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
accessKey string
|
||||
secretKey string
|
||||
|
||||
@@ -97,6 +97,7 @@ type responsePayload struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -56,6 +58,7 @@ type responseTask struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -68,15 +71,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||
var req struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
var req relaycommon.TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
}
|
||||
// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -87,6 +90,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
return relaycommon.ValidateMultipartDirect(c, info)
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||
// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := relaycommon.GetTaskRequest(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
seconds, _ := strconv.Atoi(req.Seconds)
|
||||
if seconds == 0 {
|
||||
seconds = req.Duration
|
||||
}
|
||||
if seconds <= 0 {
|
||||
seconds = 4
|
||||
}
|
||||
|
||||
size := req.Size
|
||||
if size == "" {
|
||||
size = "720x1280"
|
||||
}
|
||||
|
||||
ratios := map[string]float64{
|
||||
"seconds": float64(seconds),
|
||||
"size": 1,
|
||||
}
|
||||
if size == "1792x1024" || size == "1024x1792" {
|
||||
ratios["size"] = 1.666667
|
||||
}
|
||||
return ratios
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.Action == constant.TaskActionRemix {
|
||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
)
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
@@ -79,10 +81,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
sunoRequest, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("task_request not found in context")
|
||||
}
|
||||
data, err := common.Marshal(sunoRequest)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,7 +5,10 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
|
||||
@@ -68,3 +71,25 @@ const (
|
||||
ProgressInProgress = "30%"
|
||||
ProgressComplete = "100%"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
|
||||
// Adaptors that do not need custom billing can embed this struct directly.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type BaseBilling struct{}
|
||||
|
||||
// EstimateBilling returns nil (no extra ratios; use base model price).
|
||||
func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
|
||||
func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
|
||||
func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -62,6 +62,7 @@ type operationResponse struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
return nil
|
||||
}
|
||||
|
||||
// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
|
||||
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||
sampleCount := 1
|
||||
v, ok := c.Get("task_request")
|
||||
if ok {
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
if req.Metadata != nil {
|
||||
if sc, exists := req.Metadata["sampleCount"]; exists {
|
||||
if i, ok := sc.(int); ok && i > 0 {
|
||||
sampleCount = i
|
||||
}
|
||||
if f, ok := sc.(float64); ok && int(f) > 0 {
|
||||
sampleCount = int(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return map[string]float64{
|
||||
"sampleCount": float64(sampleCount),
|
||||
}
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Vertex specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, ok := c.Get("task_request")
|
||||
@@ -166,24 +189,6 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, fmt.Errorf("sampleCount must be greater than 0")
|
||||
}
|
||||
|
||||
// if req.Duration > 0 {
|
||||
// body.Parameters["durationSeconds"] = req.Duration
|
||||
// } else if req.Seconds != "" {
|
||||
// seconds, err := strconv.Atoi(req.Seconds)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrap(err, "convert seconds to int failed")
|
||||
// }
|
||||
// body.Parameters["durationSeconds"] = seconds
|
||||
// }
|
||||
|
||||
info.PriceData.OtherRatios = map[string]float64{
|
||||
"sampleCount": float64(body.Parameters["sampleCount"].(int)),
|
||||
}
|
||||
|
||||
// if v, ok := body.Parameters["durationSeconds"]; ok {
|
||||
// info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
|
||||
// }
|
||||
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -73,6 +73,7 @@ type creation struct {
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
taskcommon.BaseBilling
|
||||
ChannelType int
|
||||
baseURL string
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user