mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-19 09:28:37 +00:00
feat(task): add adaptor billing interface and async settlement framework
Add three billing lifecycle methods to the TaskAdaptor interface: - EstimateBilling: compute OtherRatios from user request before pricing - AdjustBillingOnSubmit: adjust ratios from upstream submit response - AdjustBillingOnComplete: determine final quota at task terminal state Introduce BaseBilling as embeddable no-op default for adaptors without custom billing. Move Sora/Ali OtherRatios logic from shared validation into per-adaptor EstimateBilling implementations. Add TaskBillingContext to persist pricing params (model_price, group_ratio, other_ratios) in task private data for async polling settlement. Extract RecalculateTaskQuota as a general-purpose delta settlement function and unify polling billing via settleTaskBillingOnComplete (adaptor-first, then token-based fallback).
This commit is contained in:
@@ -509,6 +509,13 @@ func RelayTask(c *gin.Context) {
|
|||||||
task.PrivateData.BillingSource = relayInfo.BillingSource
|
task.PrivateData.BillingSource = relayInfo.BillingSource
|
||||||
task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
|
task.PrivateData.SubscriptionId = relayInfo.SubscriptionId
|
||||||
task.PrivateData.TokenId = relayInfo.TokenId
|
task.PrivateData.TokenId = relayInfo.TokenId
|
||||||
|
task.PrivateData.BillingContext = &model.TaskBillingContext{
|
||||||
|
ModelPrice: relayInfo.PriceData.ModelPrice,
|
||||||
|
GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio,
|
||||||
|
ModelRatio: relayInfo.PriceData.ModelRatio,
|
||||||
|
OtherRatios: relayInfo.PriceData.OtherRatios,
|
||||||
|
ModelName: result.ModelName,
|
||||||
|
}
|
||||||
task.Quota = result.Quota
|
task.Quota = result.Quota
|
||||||
task.Data = result.TaskData
|
task.Data = result.TaskData
|
||||||
task.Action = relayInfo.Action
|
task.Action = relayInfo.Action
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package logger
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@@ -151,7 +150,7 @@ func FormatQuota(quota int) string {
|
|||||||
|
|
||||||
// LogJson 仅供测试使用 only for test
|
// LogJson 仅供测试使用 only for test
|
||||||
func LogJson(ctx context.Context, msg string, obj any) {
|
func LogJson(ctx context.Context, msg string, obj any) {
|
||||||
jsonStr, err := json.Marshal(obj)
|
jsonStr, err := common.Marshal(obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -100,9 +100,19 @@ type TaskPrivateData struct {
|
|||||||
UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
|
UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID
|
||||||
ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等)
|
ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等)
|
||||||
// 计费上下文:用于异步退款/差额结算(轮询阶段读取)
|
// 计费上下文:用于异步退款/差额结算(轮询阶段读取)
|
||||||
BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription"
|
BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription"
|
||||||
SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
|
SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款
|
||||||
TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款
|
TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款
|
||||||
|
BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。
|
||||||
|
type TaskBillingContext struct {
|
||||||
|
ModelPrice float64 `json:"model_price,omitempty"` // 模型单价
|
||||||
|
GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率
|
||||||
|
ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率
|
||||||
|
OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等)
|
||||||
|
ModelName string `json:"model_name,omitempty"` // 模型名称
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
|
// GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信)
|
||||||
|
|||||||
@@ -36,6 +36,32 @@ type TaskAdaptor interface {
|
|||||||
|
|
||||||
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
|
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError
|
||||||
|
|
||||||
|
// ── Billing ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// EstimateBilling returns OtherRatios for pre-charge based on user request.
|
||||||
|
// Called after ValidateRequestAndSetAction, before price calculation.
|
||||||
|
// Adaptors should extract duration, resolution, etc. from the parsed request
|
||||||
|
// and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}).
|
||||||
|
// Return nil to use the base model price without extra ratios.
|
||||||
|
EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64
|
||||||
|
|
||||||
|
// AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream
|
||||||
|
// submit response. Called after a successful DoResponse.
|
||||||
|
// If the upstream returned actual parameters that differ from the estimate
|
||||||
|
// (e.g. actual seconds), return updated ratios so the caller can recalculate
|
||||||
|
// the quota and settle the delta with the pre-charge.
|
||||||
|
// Return nil if no adjustment is needed.
|
||||||
|
AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64
|
||||||
|
|
||||||
|
// AdjustBillingOnComplete returns the actual quota when a task reaches a
|
||||||
|
// terminal state (success/failure) during polling.
|
||||||
|
// Called by the polling loop after ParseTaskResult.
|
||||||
|
// Return a positive value to trigger delta settlement (supplement / refund).
|
||||||
|
// Return 0 to keep the pre-charged amount unchanged.
|
||||||
|
AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
|
||||||
|
|
||||||
|
// ── Request / Response ───────────────────────────────────────────
|
||||||
|
|
||||||
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
|
BuildRequestURL(info *relaycommon.RelayInfo) (string, error)
|
||||||
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
|
||||||
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
|
BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error)
|
||||||
@@ -46,9 +72,9 @@ type TaskAdaptor interface {
|
|||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
GetChannelName() string
|
GetChannelName() string
|
||||||
|
|
||||||
// FetchTask
|
// ── Polling ──────────────────────────────────────────────────────
|
||||||
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
|
||||||
|
|
||||||
|
FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/logger"
|
"github.com/QuantumNous/new-api/logger"
|
||||||
"github.com/QuantumNous/new-api/model"
|
"github.com/QuantumNous/new-api/model"
|
||||||
"github.com/QuantumNous/new-api/relay/channel"
|
"github.com/QuantumNous/new-api/relay/channel"
|
||||||
|
"github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
@@ -108,10 +109,10 @@ type AliMetadata struct {
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
apiKey string
|
apiKey string
|
||||||
baseURL string
|
baseURL string
|
||||||
aliReq *AliVideoRequest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
@@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||||
// 阿里通义万相支持 JSON 格式,不使用 multipart
|
// ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context
|
||||||
var taskReq relaycommon.TaskSubmitReq
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil {
|
|
||||||
return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
aliReq, err := a.convertToAliRequest(info, taskReq)
|
|
||||||
if err != nil {
|
|
||||||
return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
a.aliReq = aliReq
|
|
||||||
logger.LogJson(c, "ali video request body", aliReq)
|
|
||||||
return relaycommon.ValidateMultipartDirect(c, info)
|
return relaycommon.ValidateMultipartDirect(c, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||||
bodyBytes, err := common.Marshal(a.aliReq)
|
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "get_task_request_failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "convert_to_ali_request_failed")
|
||||||
|
}
|
||||||
|
logger.LogJson(c, "ali video request body", aliReq)
|
||||||
|
|
||||||
|
bodyBytes, err := common.Marshal(aliReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "marshal_ali_request_failed")
|
return nil, errors.Wrap(err, "marshal_ali_request_failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
return bytes.NewReader(bodyBytes), nil
|
return bytes.NewReader(bodyBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,19 +336,33 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay
|
|||||||
return nil, errors.New("can't change model with metadata")
|
return nil, errors.New("can't change model with metadata")
|
||||||
}
|
}
|
||||||
|
|
||||||
info.PriceData.OtherRatios = map[string]float64{
|
return aliReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。
|
||||||
|
// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。
|
||||||
|
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||||
|
taskReq, err := relaycommon.GetTaskRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
aliReq, err := a.convertToAliRequest(info, taskReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
otherRatios := map[string]float64{
|
||||||
"seconds": float64(aliReq.Parameters.Duration),
|
"seconds": float64(aliReq.Parameters.Duration),
|
||||||
}
|
}
|
||||||
|
|
||||||
ratios, err := ProcessAliOtherRatios(aliReq)
|
ratios, err := ProcessAliOtherRatios(aliReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return otherRatios
|
||||||
}
|
}
|
||||||
for s, f := range ratios {
|
for k, v := range ratios {
|
||||||
info.PriceData.OtherRatios[s] = f
|
otherRatios[k] = v
|
||||||
}
|
}
|
||||||
|
return otherRatios
|
||||||
return aliReq, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoRequest delegates to common helper
|
// DoRequest delegates to common helper
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ type responseTask struct {
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
apiKey string
|
apiKey string
|
||||||
baseURL string
|
baseURL string
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ type operationResponse struct {
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
apiKey string
|
apiKey string
|
||||||
baseURL string
|
baseURL string
|
||||||
|
|||||||
@@ -17,12 +17,14 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/constant"
|
"github.com/QuantumNous/new-api/constant"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
"github.com/QuantumNous/new-api/relay/channel"
|
"github.com/QuantumNous/new-api/relay/channel"
|
||||||
|
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
|
// https://platform.minimaxi.com/docs/api-reference/video-generation-intro
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
apiKey string
|
apiKey string
|
||||||
baseURL string
|
baseURL string
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ const (
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
accessKey string
|
accessKey string
|
||||||
secretKey string
|
secretKey string
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ type responsePayload struct {
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
apiKey string
|
apiKey string
|
||||||
baseURL string
|
baseURL string
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
@@ -11,6 +12,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
"github.com/QuantumNous/new-api/model"
|
"github.com/QuantumNous/new-api/model"
|
||||||
"github.com/QuantumNous/new-api/relay/channel"
|
"github.com/QuantumNous/new-api/relay/channel"
|
||||||
|
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
|
||||||
@@ -56,6 +58,7 @@ type responseTask struct {
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
apiKey string
|
apiKey string
|
||||||
baseURL string
|
baseURL string
|
||||||
@@ -68,15 +71,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
func validateRemixRequest(c *gin.Context) *dto.TaskError {
|
||||||
var req struct {
|
var req relaycommon.TaskSubmitReq
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
}
|
|
||||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||||
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(req.Prompt) == "" {
|
if strings.TrimSpace(req.Prompt) == "" {
|
||||||
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
// 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致
|
||||||
|
c.Set("task_request", req)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,6 +90,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
|||||||
return relaycommon.ValidateMultipartDirect(c, info)
|
return relaycommon.ValidateMultipartDirect(c, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。
|
||||||
|
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 {
|
||||||
|
// remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置
|
||||||
|
if info.Action == constant.TaskActionRemix {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := relaycommon.GetTaskRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
seconds, _ := strconv.Atoi(req.Seconds)
|
||||||
|
if seconds == 0 {
|
||||||
|
seconds = req.Duration
|
||||||
|
}
|
||||||
|
if seconds <= 0 {
|
||||||
|
seconds = 4
|
||||||
|
}
|
||||||
|
|
||||||
|
size := req.Size
|
||||||
|
if size == "" {
|
||||||
|
size = "720x1280"
|
||||||
|
}
|
||||||
|
|
||||||
|
ratios := map[string]float64{
|
||||||
|
"seconds": float64(seconds),
|
||||||
|
"size": 1,
|
||||||
|
}
|
||||||
|
if size == "1792x1024" || size == "1024x1792" {
|
||||||
|
ratios["size"] = 1.666667
|
||||||
|
}
|
||||||
|
return ratios
|
||||||
|
}
|
||||||
|
|
||||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
if info.Action == constant.TaskActionRemix {
|
if info.Action == constant.TaskActionRemix {
|
||||||
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/QuantumNous/new-api/constant"
|
"github.com/QuantumNous/new-api/constant"
|
||||||
"github.com/QuantumNous/new-api/dto"
|
"github.com/QuantumNous/new-api/dto"
|
||||||
"github.com/QuantumNous/new-api/relay/channel"
|
"github.com/QuantumNous/new-api/relay/channel"
|
||||||
|
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/service"
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
|
||||||
@@ -20,6 +21,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,10 +81,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
|||||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||||
sunoRequest, ok := c.Get("task_request")
|
sunoRequest, ok := c.Get("task_request")
|
||||||
if !ok {
|
if !ok {
|
||||||
err := common.UnmarshalBodyReusable(c, &sunoRequest)
|
return nil, fmt.Errorf("task_request not found in context")
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
data, err := common.Marshal(sunoRequest)
|
data, err := common.Marshal(sunoRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/QuantumNous/new-api/common"
|
"github.com/QuantumNous/new-api/common"
|
||||||
|
"github.com/QuantumNous/new-api/model"
|
||||||
|
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
|
// UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip.
|
||||||
@@ -68,3 +71,25 @@ const (
|
|||||||
ProgressInProgress = "30%"
|
ProgressInProgress = "30%"
|
||||||
ProgressComplete = "100%"
|
ProgressComplete = "100%"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods.
|
||||||
|
// Adaptors that do not need custom billing can embed this struct directly.
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type BaseBilling struct{}
|
||||||
|
|
||||||
|
// EstimateBilling returns nil (no extra ratios; use base model price).
|
||||||
|
func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdjustBillingOnSubmit returns nil (no submit-time adjustment).
|
||||||
|
func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdjustBillingOnComplete returns 0 (keep pre-charged amount).
|
||||||
|
func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ type operationResponse struct {
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
apiKey string
|
apiKey string
|
||||||
baseURL string
|
baseURL string
|
||||||
@@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。
|
||||||
|
func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 {
|
||||||
|
sampleCount := 1
|
||||||
|
v, ok := c.Get("task_request")
|
||||||
|
if ok {
|
||||||
|
req := v.(relaycommon.TaskSubmitReq)
|
||||||
|
if req.Metadata != nil {
|
||||||
|
if sc, exists := req.Metadata["sampleCount"]; exists {
|
||||||
|
if i, ok := sc.(int); ok && i > 0 {
|
||||||
|
sampleCount = i
|
||||||
|
}
|
||||||
|
if f, ok := sc.(float64); ok && int(f) > 0 {
|
||||||
|
sampleCount = int(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]float64{
|
||||||
|
"sampleCount": float64(sampleCount),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// BuildRequestBody converts request into Vertex specific format.
|
// BuildRequestBody converts request into Vertex specific format.
|
||||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||||
v, ok := c.Get("task_request")
|
v, ok := c.Get("task_request")
|
||||||
@@ -166,24 +189,6 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
return nil, fmt.Errorf("sampleCount must be greater than 0")
|
return nil, fmt.Errorf("sampleCount must be greater than 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if req.Duration > 0 {
|
|
||||||
// body.Parameters["durationSeconds"] = req.Duration
|
|
||||||
// } else if req.Seconds != "" {
|
|
||||||
// seconds, err := strconv.Atoi(req.Seconds)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, errors.Wrap(err, "convert seconds to int failed")
|
|
||||||
// }
|
|
||||||
// body.Parameters["durationSeconds"] = seconds
|
|
||||||
// }
|
|
||||||
|
|
||||||
info.PriceData.OtherRatios = map[string]float64{
|
|
||||||
"sampleCount": float64(body.Parameters["sampleCount"].(int)),
|
|
||||||
}
|
|
||||||
|
|
||||||
// if v, ok := body.Parameters["durationSeconds"]; ok {
|
|
||||||
// info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
|
|
||||||
// }
|
|
||||||
|
|
||||||
data, err := common.Marshal(body)
|
data, err := common.Marshal(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ type creation struct {
|
|||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
type TaskAdaptor struct {
|
type TaskAdaptor struct {
|
||||||
|
taskcommon.BaseBilling
|
||||||
ChannelType int
|
ChannelType int
|
||||||
baseURL string
|
baseURL string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
|||||||
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
|
if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) {
|
||||||
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
|
return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true)
|
||||||
}
|
}
|
||||||
info.PriceData.OtherRatios = map[string]float64{
|
// OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置
|
||||||
"seconds": float64(seconds),
|
|
||||||
"size": 1,
|
|
||||||
}
|
|
||||||
if lo.Contains([]string{"1792x1024", "1024x1792"}, size) {
|
|
||||||
info.PriceData.OtherRatios["size"] = 1.666667
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info.Action = action
|
storeTaskRequest(c, info, action, req)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,8 +128,9 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
|
// RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次):
|
||||||
// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 →
|
// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 →
|
||||||
// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。
|
// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→
|
||||||
|
// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。
|
||||||
// 控制器负责 defer Refund 和成功后 Settle。
|
// 控制器负责 defer Refund 和成功后 Settle。
|
||||||
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
|
func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) {
|
||||||
info.InitChannelMeta(c)
|
info.InitChannelMeta(c)
|
||||||
@@ -159,10 +160,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|||||||
info.PublicTaskID = model.GenerateTaskID()
|
info.PublicTaskID = model.GenerateTaskID()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 价格计算
|
// 4. 价格计算:基础模型价格
|
||||||
info.OriginModelName = modelName
|
info.OriginModelName = modelName
|
||||||
info.PriceData = helper.ModelPriceHelperPerCall(c, info)
|
info.PriceData = helper.ModelPriceHelperPerCall(c, info)
|
||||||
|
|
||||||
|
// 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等)
|
||||||
|
// 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。
|
||||||
|
// ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。
|
||||||
|
if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 {
|
||||||
|
for k, v := range estimatedRatios {
|
||||||
|
info.PriceData.AddOtherRatio(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. 将 OtherRatios 应用到基础额度
|
||||||
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||||
for _, ra := range info.PriceData.OtherRatios {
|
for _, ra := range info.PriceData.OtherRatios {
|
||||||
if ra != 1.0 {
|
if ra != 1.0 {
|
||||||
@@ -171,7 +182,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
|
// 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过)
|
||||||
if info.Billing == nil && !info.PriceData.FreeModel {
|
if info.Billing == nil && !info.PriceData.FreeModel {
|
||||||
info.ForcePreConsume = true
|
info.ForcePreConsume = true
|
||||||
if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
|
if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil {
|
||||||
@@ -179,13 +190,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. 构建请求体
|
// 8. 构建请求体
|
||||||
requestBody, err := adaptor.BuildRequestBody(c, info)
|
requestBody, err := adaptor.BuildRequestBody(c, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. 发送请求
|
// 9. 发送请求
|
||||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
@@ -195,20 +206,59 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe
|
|||||||
return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8. 解析响应
|
// 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置)
|
||||||
|
otherRatios := info.PriceData.OtherRatios
|
||||||
|
if otherRatios == nil {
|
||||||
|
otherRatios = map[string]float64{}
|
||||||
|
}
|
||||||
|
ratiosJSON, _ := common.Marshal(otherRatios)
|
||||||
|
c.Header("X-New-Api-Other-Ratios", string(ratiosJSON))
|
||||||
|
|
||||||
|
// 11. 解析响应
|
||||||
upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info)
|
||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
return nil, taskErr
|
return nil, taskErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios
|
||||||
|
finalQuota := info.PriceData.Quota
|
||||||
|
if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 {
|
||||||
|
// 基于调整后的 ratios 重新计算 quota
|
||||||
|
finalQuota = recalcQuotaFromRatios(info, adjustedRatios)
|
||||||
|
info.PriceData.OtherRatios = adjustedRatios
|
||||||
|
info.PriceData.Quota = finalQuota
|
||||||
|
}
|
||||||
|
|
||||||
return &TaskSubmitResult{
|
return &TaskSubmitResult{
|
||||||
UpstreamTaskID: upstreamTaskID,
|
UpstreamTaskID: upstreamTaskID,
|
||||||
TaskData: taskData,
|
TaskData: taskData,
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
|
Quota: finalQuota,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。
|
||||||
|
// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。
|
||||||
|
func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int {
|
||||||
|
// 从 PriceData 获取不含 OtherRatios 的基础价格
|
||||||
|
baseQuota := info.PriceData.Quota
|
||||||
|
// 先除掉原有的 OtherRatios 恢复基础额度
|
||||||
|
for _, ra := range info.PriceData.OtherRatios {
|
||||||
|
if ra != 1.0 && ra > 0 {
|
||||||
|
baseQuota = int(float64(baseQuota) / ra)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 应用新的 ratios
|
||||||
|
result := float64(baseQuota)
|
||||||
|
for _, ra := range ratios {
|
||||||
|
if ra != 1.0 {
|
||||||
|
result *= ra
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return int(result)
|
||||||
|
}
|
||||||
|
|
||||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||||
|
|||||||
@@ -130,6 +130,58 @@ func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
|
|||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecalculateTaskQuota 通用的异步差额结算。
|
||||||
|
// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
|
||||||
|
// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
|
||||||
|
func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
|
||||||
|
if actualQuota <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
preConsumedQuota := task.Quota
|
||||||
|
quotaDelta := actualQuota - preConsumedQuota
|
||||||
|
|
||||||
|
if quotaDelta == 0 {
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
|
||||||
|
task.TaskID, logger.LogQuota(actualQuota), reason))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
|
||||||
|
task.TaskID,
|
||||||
|
logger.LogQuota(quotaDelta),
|
||||||
|
logger.LogQuota(actualQuota),
|
||||||
|
logger.LogQuota(preConsumedQuota),
|
||||||
|
reason,
|
||||||
|
))
|
||||||
|
|
||||||
|
// 调整资金来源
|
||||||
|
if err := taskAdjustFunding(task, quotaDelta); err != nil {
|
||||||
|
logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调整令牌额度
|
||||||
|
taskAdjustTokenQuota(ctx, task, quotaDelta)
|
||||||
|
|
||||||
|
// 更新统计(仅补扣时更新,退还不影响已用统计)
|
||||||
|
if quotaDelta > 0 {
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||||||
|
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||||||
|
}
|
||||||
|
task.Quota = actualQuota
|
||||||
|
|
||||||
|
var action string
|
||||||
|
if quotaDelta > 0 {
|
||||||
|
action = "补扣费"
|
||||||
|
} else {
|
||||||
|
action = "退还"
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("异步任务成功%s,预扣费 %s,实际扣费 %s,原因:%s",
|
||||||
|
action,
|
||||||
|
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), reason)
|
||||||
|
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||||
|
}
|
||||||
|
|
||||||
// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
|
// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
|
||||||
// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
|
// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
|
||||||
// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
|
// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
|
||||||
@@ -180,48 +232,6 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo
|
|||||||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||||||
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
|
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
|
||||||
|
|
||||||
// 计算差额(正数=需要补扣,负数=需要退还)
|
reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
|
||||||
preConsumedQuota := task.Quota
|
RecalculateTaskQuota(ctx, task, actualQuota, reason)
|
||||||
quotaDelta := actualQuota - preConsumedQuota
|
|
||||||
|
|
||||||
if quotaDelta == 0 {
|
|
||||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)",
|
|
||||||
task.TaskID, logger.LogQuota(actualQuota), totalTokens))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)",
|
|
||||||
task.TaskID,
|
|
||||||
logger.LogQuota(quotaDelta),
|
|
||||||
logger.LogQuota(actualQuota),
|
|
||||||
logger.LogQuota(preConsumedQuota),
|
|
||||||
totalTokens,
|
|
||||||
))
|
|
||||||
|
|
||||||
// 调整资金来源
|
|
||||||
if err := taskAdjustFunding(task, quotaDelta); err != nil {
|
|
||||||
logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 调整令牌额度
|
|
||||||
taskAdjustTokenQuota(ctx, task, quotaDelta)
|
|
||||||
|
|
||||||
// 更新统计(仅补扣时更新,退还不影响已用统计)
|
|
||||||
if quotaDelta > 0 {
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
|
||||||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
|
||||||
}
|
|
||||||
task.Quota = actualQuota
|
|
||||||
|
|
||||||
var action string
|
|
||||||
if quotaDelta > 0 {
|
|
||||||
action = "补扣费"
|
|
||||||
} else {
|
|
||||||
action = "退还"
|
|
||||||
}
|
|
||||||
logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s",
|
|
||||||
action, modelRatio, finalGroupRatio, totalTokens,
|
|
||||||
logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota))
|
|
||||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ type TaskPollingAdaptor interface {
|
|||||||
Init(info *relaycommon.RelayInfo)
|
Init(info *relaycommon.RelayInfo)
|
||||||
FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
|
FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error)
|
||||||
ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
|
ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error)
|
||||||
|
// AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。
|
||||||
|
// 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。
|
||||||
|
AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
|
// GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。
|
||||||
@@ -372,10 +375,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *
|
|||||||
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
|
task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果返回了 total_tokens,根据模型倍率重新计费
|
// 完成时计费调整:优先由 adaptor 计算,回退到 token 重算
|
||||||
if taskResult.TotalTokens > 0 {
|
settleTaskBillingOnComplete(ctx, adaptor, task, taskResult)
|
||||||
RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
|
|
||||||
}
|
|
||||||
case model.TaskStatusFailure:
|
case model.TaskStatusFailure:
|
||||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||||
task.Status = model.TaskStatusFailure
|
task.Status = model.TaskStatusFailure
|
||||||
@@ -444,3 +445,22 @@ func truncateBase64(s string) string {
|
|||||||
}
|
}
|
||||||
return s[:maxKeep] + "..."
|
return s[:maxKeep] + "..."
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// settleTaskBillingOnComplete 任务完成时的统一计费调整。
|
||||||
|
// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度
|
||||||
|
//
|
||||||
|
// 2. taskResult.TotalTokens > 0 → 按 token 重算
|
||||||
|
// 3. 都不满足 → 保持预扣额度不变
|
||||||
|
func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) {
|
||||||
|
// 1. 优先让 adaptor 决定最终额度
|
||||||
|
if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 {
|
||||||
|
RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 2. 回退到 token 重算
|
||||||
|
if taskResult.TotalTokens > 0 {
|
||||||
|
RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 3. 无调整,保持预扣额度
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user