diff --git a/controller/relay.go b/controller/relay.go index 132fee9ba..3d2f20e82 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -509,6 +509,13 @@ func RelayTask(c *gin.Context) { task.PrivateData.BillingSource = relayInfo.BillingSource task.PrivateData.SubscriptionId = relayInfo.SubscriptionId task.PrivateData.TokenId = relayInfo.TokenId + task.PrivateData.BillingContext = &model.TaskBillingContext{ + ModelPrice: relayInfo.PriceData.ModelPrice, + GroupRatio: relayInfo.PriceData.GroupRatioInfo.GroupRatio, + ModelRatio: relayInfo.PriceData.ModelRatio, + OtherRatios: relayInfo.PriceData.OtherRatios, + ModelName: result.ModelName, + } task.Quota = result.Quota task.Data = result.TaskData task.Action = relayInfo.Action diff --git a/logger/logger.go b/logger/logger.go index 61b1d49d8..90cf5006e 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,7 +2,6 @@ package logger import ( "context" - "encoding/json" "fmt" "io" "log" @@ -151,7 +150,7 @@ func FormatQuota(quota int) string { // LogJson 仅供测试使用 only for test func LogJson(ctx context.Context, msg string, obj any) { - jsonStr, err := json.Marshal(obj) + jsonStr, err := common.Marshal(obj) if err != nil { LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error())) return diff --git a/model/task.go b/model/task.go index 38bb4d05a..592643ebb 100644 --- a/model/task.go +++ b/model/task.go @@ -100,9 +100,19 @@ type TaskPrivateData struct { UpstreamTaskID string `json:"upstream_task_id,omitempty"` // 上游真实 task ID ResultURL string `json:"result_url,omitempty"` // 任务成功后的结果 URL(视频地址等) // 计费上下文:用于异步退款/差额结算(轮询阶段读取) - BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" - SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 - TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 + BillingSource string `json:"billing_source,omitempty"` // "wallet" 或 "subscription" + SubscriptionId int `json:"subscription_id,omitempty"` // 订阅 ID,用于订阅退款 + TokenId int `json:"token_id,omitempty"` // 令牌 ID,用于令牌额度退款 + BillingContext *TaskBillingContext `json:"billing_context,omitempty"` // 计费参数快照(用于轮询阶段重新计算) +} + +// TaskBillingContext 记录任务提交时的计费参数,以便轮询阶段可以重新计算额度。 +type TaskBillingContext struct { + ModelPrice float64 `json:"model_price,omitempty"` // 模型单价 + GroupRatio float64 `json:"group_ratio,omitempty"` // 分组倍率 + ModelRatio float64 `json:"model_ratio,omitempty"` // 模型倍率 + OtherRatios map[string]float64 `json:"other_ratios,omitempty"` // 附加倍率(时长、分辨率等) + ModelName string `json:"model_name,omitempty"` // 模型名称 } // GetUpstreamTaskID 获取上游真实 task ID(用于与 provider 通信) diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ff7606e2e..d2f7c6bb6 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -36,6 +36,32 @@ type TaskAdaptor interface { ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError + // ── Billing ────────────────────────────────────────────────────── + + // EstimateBilling returns OtherRatios for pre-charge based on user request. + // Called after ValidateRequestAndSetAction, before price calculation. + // Adaptors should extract duration, resolution, etc. from the parsed request + // and return them as ratio multipliers (e.g. {"seconds": 5, "size": 1.666}). + // Return nil to use the base model price without extra ratios. + EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 + + // AdjustBillingOnSubmit returns adjusted OtherRatios from the upstream + // submit response. Called after a successful DoResponse. + // If the upstream returned actual parameters that differ from the estimate + // (e.g. actual seconds), return updated ratios so the caller can recalculate + // the quota and settle the delta with the pre-charge. + // Return nil if no adjustment is needed. + AdjustBillingOnSubmit(info *relaycommon.RelayInfo, taskData []byte) map[string]float64 + + // AdjustBillingOnComplete returns the actual quota when a task reaches a + // terminal state (success/failure) during polling. + // Called by the polling loop after ParseTaskResult. + // Return a positive value to trigger delta settlement (supplement / refund). + // Return 0 to keep the pre-charged amount unchanged. + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int + + // ── Request / Response ─────────────────────────────────────────── + BuildRequestURL(info *relaycommon.RelayInfo) (string, error) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) @@ -46,9 +72,9 @@ type TaskAdaptor interface { GetModelList() []string GetChannelName() string - // FetchTask - FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) + // ── Polling ────────────────────────────────────────────────────── + FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) } diff --git a/relay/channel/task/ali/adaptor.go b/relay/channel/task/ali/adaptor.go index 5d14ff655..f55178b3b 100644 --- a/relay/channel/task/ali/adaptor.go +++ b/relay/channel/task/ali/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/logger" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" "github.com/samber/lo" @@ -108,10 +109,10 @@ type AliMetadata struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string - aliReq *AliVideoRequest } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { @@ -121,17 +122,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - // 阿里通义万相支持 JSON 格式,不使用 multipart - var taskReq relaycommon.TaskSubmitReq - if err := common.UnmarshalBodyReusable(c, &taskReq); err != nil { - return service.TaskErrorWrapper(err, "unmarshal_task_request_failed", http.StatusBadRequest) - } - aliReq, err := a.convertToAliRequest(info, taskReq) - if err != nil { - return service.TaskErrorWrapper(err, "convert_to_ali_request_failed", http.StatusInternalServerError) - } - a.aliReq = aliReq - logger.LogJson(c, "ali video request body", aliReq) + // ValidateMultipartDirect 负责解析并将原始 TaskSubmitReq 存入 context return relaycommon.ValidateMultipartDirect(c, info) } @@ -148,11 +139,21 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { - bodyBytes, err := common.Marshal(a.aliReq) + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil, errors.Wrap(err, "get_task_request_failed") + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil, errors.Wrap(err, "convert_to_ali_request_failed") + } + logger.LogJson(c, "ali video request body", aliReq) + + bodyBytes, err := common.Marshal(aliReq) if err != nil { return nil, errors.Wrap(err, "marshal_ali_request_failed") } - return bytes.NewReader(bodyBytes), nil } @@ -335,19 +336,33 @@ func (a *TaskAdaptor) convertToAliRequest(info *relaycommon.RelayInfo, req relay return nil, errors.New("can't change model with metadata") } - info.PriceData.OtherRatios = map[string]float64{ + return aliReq, nil +} + +// EstimateBilling 根据用户请求参数计算 OtherRatios(时长、分辨率等)。 +// 在 ValidateRequestAndSetAction 之后、价格计算之前调用。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + taskReq, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + aliReq, err := a.convertToAliRequest(info, taskReq) + if err != nil { + return nil + } + + otherRatios := map[string]float64{ "seconds": float64(aliReq.Parameters.Duration), } - ratios, err := ProcessAliOtherRatios(aliReq) if err != nil { - return nil, err + return otherRatios } - for s, f := range ratios { - info.PriceData.OtherRatios[s] = f + for k, v := range ratios { + otherRatios[k] = v } - - return aliReq, nil + return otherRatios } // DoRequest delegates to common helper diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 3da125afc..eca421bd3 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -89,6 +89,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/gemini/adaptor.go b/relay/channel/task/gemini/adaptor.go index a863ea852..06c00a469 100644 --- a/relay/channel/task/gemini/adaptor.go +++ b/relay/channel/task/gemini/adaptor.go @@ -85,6 +85,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/hailuo/adaptor.go b/relay/channel/task/hailuo/adaptor.go index 67a68a10e..ab83d659b 100644 --- a/relay/channel/task/hailuo/adaptor.go +++ b/relay/channel/task/hailuo/adaptor.go @@ -17,12 +17,14 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) // https://platform.minimaxi.com/docs/api-reference/video-generation-intro type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index 7f88be248..b61cca418 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -77,6 +77,7 @@ const ( // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int accessKey string secretKey string diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 4458626b2..46e210f19 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -97,6 +97,7 @@ type responsePayload struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string diff --git a/relay/channel/task/sora/adaptor.go b/relay/channel/task/sora/adaptor.go index ee69a3e48..8faaf984f 100644 --- a/relay/channel/task/sora/adaptor.go +++ b/relay/channel/task/sora/adaptor.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" "github.com/QuantumNous/new-api/common" @@ -11,6 +12,7 @@ import ( "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/model" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -56,6 +58,7 @@ type responseTask struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -68,15 +71,15 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } func validateRemixRequest(c *gin.Context) *dto.TaskError { - var req struct { - Prompt string `json:"prompt"` - } + var req relaycommon.TaskSubmitReq if err := common.UnmarshalBodyReusable(c, &req); err != nil { return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) } if strings.TrimSpace(req.Prompt) == "" { return service.TaskErrorWrapperLocal(fmt.Errorf("field prompt is required"), "invalid_request", http.StatusBadRequest) } + // 存储原始请求到 context,与 ValidateMultipartDirect 路径保持一致 + c.Set("task_request", req) return nil } @@ -87,6 +90,41 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return relaycommon.ValidateMultipartDirect(c, info) } +// EstimateBilling 根据用户请求的 seconds 和 size 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { + // remix 路径的 OtherRatios 已在 ResolveOriginTask 中设置 + if info.Action == constant.TaskActionRemix { + return nil + } + + req, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + + seconds, _ := strconv.Atoi(req.Seconds) + if seconds == 0 { + seconds = req.Duration + } + if seconds <= 0 { + seconds = 4 + } + + size := req.Size + if size == "" { + size = "720x1280" + } + + ratios := map[string]float64{ + "seconds": float64(seconds), + "size": 1, + } + if size == "1792x1024" || size == "1024x1792" { + ratios["size"] = 1.666667 + } + return ratios +} + func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { if info.Action == constant.TaskActionRemix { return fmt.Sprintf("%s/v1/videos/%s/remix", a.baseURL, info.OriginTaskID), nil diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index 5dd62a70f..2dbb44f00 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -13,6 +13,7 @@ import ( "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" + taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" @@ -20,6 +21,7 @@ import ( ) type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int } @@ -79,10 +81,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { sunoRequest, ok := c.Get("task_request") if !ok { - err := common.UnmarshalBodyReusable(c, &sunoRequest) - if err != nil { - return nil, err - } + return nil, fmt.Errorf("task_request not found in context") } data, err := common.Marshal(sunoRequest) if err != nil { diff --git a/relay/channel/task/taskcommon/helpers.go b/relay/channel/task/taskcommon/helpers.go index b1dde998b..27d6612d4 100644 --- a/relay/channel/task/taskcommon/helpers.go +++ b/relay/channel/task/taskcommon/helpers.go @@ -5,7 +5,10 @@ import ( "fmt" "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/setting/system_setting" + "github.com/gin-gonic/gin" ) // UnmarshalMetadata converts a map[string]any metadata to a typed struct via JSON round-trip. @@ -68,3 +71,25 @@ const ( ProgressInProgress = "30%" ProgressComplete = "100%" ) + +// --------------------------------------------------------------------------- +// BaseBilling — embeddable no-op implementations for TaskAdaptor billing methods. +// Adaptors that do not need custom billing can embed this struct directly. +// --------------------------------------------------------------------------- + +type BaseBilling struct{} + +// EstimateBilling returns nil (no extra ratios; use base model price). +func (BaseBilling) EstimateBilling(_ *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + return nil +} + +// AdjustBillingOnSubmit returns nil (no submit-time adjustment). +func (BaseBilling) AdjustBillingOnSubmit(_ *relaycommon.RelayInfo, _ []byte) map[string]float64 { + return nil +} + +// AdjustBillingOnComplete returns 0 (keep pre-charged amount). +func (BaseBilling) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { + return 0 +} diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index fb3a313ff..4931002dd 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -62,6 +62,7 @@ type operationResponse struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int apiKey string baseURL string @@ -133,6 +134,28 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return nil } +// EstimateBilling 根据用户请求中的 sampleCount 计算 OtherRatios。 +func (a *TaskAdaptor) EstimateBilling(c *gin.Context, _ *relaycommon.RelayInfo) map[string]float64 { + sampleCount := 1 + v, ok := c.Get("task_request") + if ok { + req := v.(relaycommon.TaskSubmitReq) + if req.Metadata != nil { + if sc, exists := req.Metadata["sampleCount"]; exists { + if i, ok := sc.(int); ok && i > 0 { + sampleCount = i + } + if f, ok := sc.(float64); ok && int(f) > 0 { + sampleCount = int(f) + } + } + } + } + return map[string]float64{ + "sampleCount": float64(sampleCount), + } +} + // BuildRequestBody converts request into Vertex specific format. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") @@ -166,24 +189,6 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn return nil, fmt.Errorf("sampleCount must be greater than 0") } - // if req.Duration > 0 { - // body.Parameters["durationSeconds"] = req.Duration - // } else if req.Seconds != "" { - // seconds, err := strconv.Atoi(req.Seconds) - // if err != nil { - // return nil, errors.Wrap(err, "convert seconds to int failed") - // } - // body.Parameters["durationSeconds"] = seconds - // } - - info.PriceData.OtherRatios = map[string]float64{ - "sampleCount": float64(body.Parameters["sampleCount"].(int)), - } - - // if v, ok := body.Parameters["durationSeconds"]; ok { - // info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int)) - // } - data, err := common.Marshal(body) if err != nil { return nil, err diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index 1bab12f03..e689bf888 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -73,6 +73,7 @@ type creation struct { // ============================ type TaskAdaptor struct { + taskcommon.BaseBilling ChannelType int baseURL string } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index b662f9053..3cbb18c22 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -173,16 +173,10 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) { return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) } - info.PriceData.OtherRatios = map[string]float64{ - "seconds": float64(seconds), - "size": 1, - } - if lo.Contains([]string{"1792x1024", "1024x1792"}, size) { - info.PriceData.OtherRatios["size"] = 1.666667 - } + // OtherRatios 已移到 Sora adaptor 的 EstimateBilling 中设置 } - info.Action = action + storeTaskRequest(c, info, action, req) return nil } diff --git a/relay/relay_task.go b/relay/relay_task.go index d372ca2e8..7c6724d80 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -128,8 +128,9 @@ func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskErr } // RelayTaskSubmit 完成 task 提交的全部流程(每次尝试调用一次): -// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → 计算价格 → -// 预扣费(仅首次,通过 info.Billing==nil 守卫)→ 构建/发送/解析上游请求。 +// 刷新渠道元数据 → 确定 platform/adaptor → 验证请求 → +// 估算计费(EstimateBilling) → 计算价格 → 预扣费(仅首次)→ +// 构建/发送/解析上游请求 → 提交后计费调整(AdjustBillingOnSubmit)。 // 控制器负责 defer Refund 和成功后 Settle。 func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { info.InitChannelMeta(c) @@ -159,10 +160,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe info.PublicTaskID = model.GenerateTaskID() } - // 4. 价格计算 + // 4. 价格计算:基础模型价格 info.OriginModelName = modelName info.PriceData = helper.ModelPriceHelperPerCall(c, info) + // 5. 计费估算:让适配器根据用户请求提供 OtherRatios(时长、分辨率等) + // 必须在 ModelPriceHelperPerCall 之后调用(它会重建 PriceData)。 + // ResolveOriginTask 可能已在 remix 路径中预设了 OtherRatios,此处合并。 + if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 { + for k, v := range estimatedRatios { + info.PriceData.AddOtherRatio(k, v) + } + } + + // 6. 将 OtherRatios 应用到基础额度 if !common.StringsContains(constant.TaskPricePatches, modelName) { for _, ra := range info.PriceData.OtherRatios { if ra != 1.0 { @@ -171,7 +182,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe } } - // 5. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) + // 7. 预扣费(仅首次 — 重试时 info.Billing 已存在,跳过) if info.Billing == nil && !info.PriceData.FreeModel { info.ForcePreConsume = true if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { @@ -179,13 +190,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe } } - // 6. 构建请求体 + // 8. 构建请求体 requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) } - // 7. 发送请求 + // 9. 发送请求 resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) @@ -195,20 +206,59 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitRe return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) } - // 8. 解析响应 + // 10. 返回 OtherRatios 给下游(header 必须在 DoResponse 写 body 之前设置) + otherRatios := info.PriceData.OtherRatios + if otherRatios == nil { + otherRatios = map[string]float64{} + } + ratiosJSON, _ := common.Marshal(otherRatios) + c.Header("X-New-Api-Other-Ratios", string(ratiosJSON)) + + // 11. 解析响应 upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { return nil, taskErr } + // 11. 提交后计费调整:让适配器根据上游实际返回调整 OtherRatios + finalQuota := info.PriceData.Quota + if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 { + // 基于调整后的 ratios 重新计算 quota + finalQuota = recalcQuotaFromRatios(info, adjustedRatios) + info.PriceData.OtherRatios = adjustedRatios + info.PriceData.Quota = finalQuota + } + return &TaskSubmitResult{ UpstreamTaskID: upstreamTaskID, TaskData: taskData, Platform: platform, ModelName: modelName, + Quota: finalQuota, }, nil } +// recalcQuotaFromRatios 根据 adjustedRatios 重新计算 quota。 +// 公式: baseQuota × ∏(ratio) — 其中 baseQuota 是不含 OtherRatios 的基础额度。 +func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int { + // 从 PriceData 获取不含 OtherRatios 的基础价格 + baseQuota := info.PriceData.Quota + // 先除掉原有的 OtherRatios 恢复基础额度 + for _, ra := range info.PriceData.OtherRatios { + if ra != 1.0 && ra > 0 { + baseQuota = int(float64(baseQuota) / ra) + } + } + // 应用新的 ratios + result := float64(baseQuota) + for _, ra := range ratios { + if ra != 1.0 { + result *= ra + } + } + return int(result) +} + var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, diff --git a/service/task_billing.go b/service/task_billing.go index ec0094bd9..fc44c5876 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -130,6 +130,58 @@ func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { model.RecordLog(task.UserId, model.LogTypeSystem, logContent) } +// RecalculateTaskQuota 通用的异步差额结算。 +// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。 +// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。 +func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) { + if actualQuota <= 0 { + return + } + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta == 0 { + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)", + task.TaskID, logger.LogQuota(actualQuota), reason)) + return + } + + logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + reason, + )) + + // 调整资金来源 + if err := taskAdjustFunding(task, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) + return + } + + // 调整令牌额度 + taskAdjustTokenQuota(ctx, task, quotaDelta) + + // 更新统计(仅补扣时更新,退还不影响已用统计) + if quotaDelta > 0 { + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } + task.Quota = actualQuota + + var action string + if quotaDelta > 0 { + action = "补扣费" + } else { + action = "退还" + } + logContent := fmt.Sprintf("异步任务成功%s,预扣费 %s,实际扣费 %s,原因:%s", + action, + logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota), reason) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) +} + // RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。 // 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度, // 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。 @@ -180,48 +232,6 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) - // 计算差额(正数=需要补扣,负数=需要退还) - preConsumedQuota := task.Quota - quotaDelta := actualQuota - preConsumedQuota - - if quotaDelta == 0 { - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", - task.TaskID, logger.LogQuota(actualQuota), totalTokens)) - return - } - - logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,tokens:%d)", - task.TaskID, - logger.LogQuota(quotaDelta), - logger.LogQuota(actualQuota), - logger.LogQuota(preConsumedQuota), - totalTokens, - )) - - // 调整资金来源 - if err := taskAdjustFunding(task, quotaDelta); err != nil { - logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) - return - } - - // 调整令牌额度 - taskAdjustTokenQuota(ctx, task, quotaDelta) - - // 更新统计(仅补扣时更新,退还不影响已用统计) - if quotaDelta > 0 { - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) - } - task.Quota = actualQuota - - var action string - if quotaDelta > 0 { - action = "补扣费" - } else { - action = "退还" - } - logContent := fmt.Sprintf("视频任务成功%s,模型倍率 %.2f,分组倍率 %.2f,tokens %d,预扣费 %s,实际扣费 %s", - action, modelRatio, finalGroupRatio, totalTokens, - logger.LogQuota(preConsumedQuota), logger.LogQuota(actualQuota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio) + RecalculateTaskQuota(ctx, task, actualQuota, reason) } diff --git a/service/task_polling.go b/service/task_polling.go index 847e1659b..efbad8afa 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -26,6 +26,9 @@ type TaskPollingAdaptor interface { Init(info *relaycommon.RelayInfo) FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) + // AdjustBillingOnComplete 在任务到达终态(成功/失败)时由轮询循环调用。 + // 返回正数触发差额结算(补扣/退还),返回 0 保持预扣费金额不变。 + AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int } // GetTaskAdaptorFunc 由 main 包注入,用于获取指定平台的任务适配器。 @@ -372,10 +375,8 @@ func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch * task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) } - // 如果返回了 total_tokens,根据模型倍率重新计费 - if taskResult.TotalTokens > 0 { - RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) - } + // 完成时计费调整:优先由 adaptor 计算,回退到 token 重算 + settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) case model.TaskStatusFailure: logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) task.Status = model.TaskStatusFailure @@ -444,3 +445,22 @@ func truncateBase64(s string) string { } return s[:maxKeep] + "..." } + +// settleTaskBillingOnComplete 任务完成时的统一计费调整。 +// 优先级:1. adaptor.AdjustBillingOnComplete 返回正数 → 使用 adaptor 计算的额度 +// +// 2. taskResult.TotalTokens > 0 → 按 token 重算 +// 3. 都不满足 → 保持预扣额度不变 +func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { + // 1. 优先让 adaptor 决定最终额度 + if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") + return + } + // 2. 回退到 token 重算 + if taskResult.TotalTokens > 0 { + RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) + return + } + // 3. 无调整,保持预扣额度 +}