mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 06:20:56 +00:00
1. Async task model redirection (aligned with sync tasks):
- Integrate ModelMappedHelper in RelayTaskSubmit after model name
determination, populating OriginModelName / UpstreamModelName on RelayInfo.
- All task adaptors now send UpstreamModelName to upstream providers:
- Gemini & Vertex: BuildRequestURL uses UpstreamModelName.
- Doubao & Ali: BuildRequestBody conditionally overwrites body.Model.
- Vidu, Kling, Hailuo, Jimeng: convertToRequestPayload accepts RelayInfo
and unconditionally uses info.UpstreamModelName.
- Sora: BuildRequestBody parses JSON and multipart bodies to replace
the "model" field with UpstreamModelName.
- Frontend log visibility: LogTaskConsumption and taskBillingOther now
emit is_model_mapped / upstream_model_name in the "other" JSON field.
- Billing safety: RecalculateTaskQuotaByTokens reads model name from
BillingContext.OriginModelName (via taskModelName) instead of
task.Data["model"], preventing billing leaks from upstream model names.
2. Per-call billing (TaskPricePatches lifecycle):
- Rename TaskBillingContext.ModelName → OriginModelName; add PerCallBilling
bool field, populated from TaskPricePatches at submission time.
- settleTaskBillingOnComplete short-circuits when PerCallBilling is true,
skipping both adaptor adjustments and token-based recalculation.
- Remove ModelName from TaskSubmitResult; use relayInfo.OriginModelName
consistently in controller/relay.go for billing context and logging.
3. Multipart retry boundary mismatch fix:
- Root cause: after Sora (or OpenAI audio) rebuilds a multipart body with a
new boundary and overwrites c.Request.Header["Content-Type"], subsequent
calls to ParseMultipartFormReusable on retry would parse the cached
original body with the wrong boundary, causing "NextPart: EOF".
- Fix: ParseMultipartFormReusable now caches the original Content-Type in
gin context key "_original_multipart_ct" on first call and reuses it for
all subsequent parses, making multipart parsing retry-safe globally.
- Sora adaptor reverted to the standard pattern (direct header set/get),
which is now safe thanks to the root fix.
4. Tests:
- task_billing_test.go: update makeTask to use OriginModelName; add
PerCallBilling settlement tests (skip adaptor adjust, skip token recalc);
add non-per-call adaptor adjustment test with refund verification.
286 lines
9.4 KiB
Go
286 lines
9.4 KiB
Go
package service
|
||
|
||
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
	"github.com/gin-gonic/gin"
)
|
||
|
||
// LogTaskConsumption 记录任务消费日志和统计信息(仅记录,不涉及实际扣费)。
|
||
// 实际扣费已由 BillingSession(PreConsumeBilling + SettleBilling)完成。
|
||
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
|
||
tokenName := c.GetString("token_name")
|
||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||
// 支持任务仅按次计费
|
||
if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
|
||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||
} else {
|
||
if len(info.PriceData.OtherRatios) > 0 {
|
||
var contents []string
|
||
for key, ra := range info.PriceData.OtherRatios {
|
||
if 1.0 != ra {
|
||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||
}
|
||
}
|
||
if len(contents) > 0 {
|
||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||
}
|
||
}
|
||
}
|
||
other := make(map[string]interface{})
|
||
other["request_path"] = c.Request.URL.Path
|
||
other["model_price"] = info.PriceData.ModelPrice
|
||
other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
|
||
if info.PriceData.GroupRatioInfo.HasSpecialRatio {
|
||
other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
|
||
}
|
||
if info.IsModelMapped {
|
||
other["is_model_mapped"] = true
|
||
other["upstream_model_name"] = info.UpstreamModelName
|
||
}
|
||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||
ChannelId: info.ChannelId,
|
||
ModelName: info.OriginModelName,
|
||
TokenName: tokenName,
|
||
Quota: info.PriceData.Quota,
|
||
Content: logContent,
|
||
TokenId: info.TokenId,
|
||
Group: info.UsingGroup,
|
||
Other: other,
|
||
})
|
||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
|
||
model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// 异步任务计费辅助函数
|
||
// ---------------------------------------------------------------------------
|
||
|
||
// resolveTokenKey 通过 TokenId 运行时获取令牌 Key(用于 Redis 缓存操作)。
|
||
// 如果令牌已被删除或查询失败,返回空字符串。
|
||
func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
|
||
token, err := model.GetTokenById(tokenId)
|
||
if err != nil {
|
||
logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error()))
|
||
return ""
|
||
}
|
||
return token.Key
|
||
}
|
||
|
||
// taskIsSubscription 判断任务是否通过订阅计费。
|
||
func taskIsSubscription(task *model.Task) bool {
|
||
return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0
|
||
}
|
||
|
||
// taskAdjustFunding 调整任务的资金来源(钱包或订阅),delta > 0 表示扣费,delta < 0 表示退还。
|
||
func taskAdjustFunding(task *model.Task, delta int) error {
|
||
if taskIsSubscription(task) {
|
||
return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
|
||
}
|
||
if delta > 0 {
|
||
return model.DecreaseUserQuota(task.UserId, delta)
|
||
}
|
||
return model.IncreaseUserQuota(task.UserId, -delta, false)
|
||
}
|
||
|
||
// taskAdjustTokenQuota 调整任务的令牌额度,delta > 0 表示扣费,delta < 0 表示退还。
|
||
// 需要通过 resolveTokenKey 运行时获取 key(不从 PrivateData 中读取)。
|
||
func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
|
||
if task.PrivateData.TokenId <= 0 || delta == 0 {
|
||
return
|
||
}
|
||
tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID)
|
||
if tokenKey == "" {
|
||
return
|
||
}
|
||
var err error
|
||
if delta > 0 {
|
||
err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta)
|
||
} else {
|
||
err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta)
|
||
}
|
||
if err != nil {
|
||
logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error()))
|
||
}
|
||
}
|
||
|
||
// taskBillingOther 从 task 的 BillingContext 构建日志 Other 字段。
|
||
func taskBillingOther(task *model.Task) map[string]interface{} {
|
||
other := make(map[string]interface{})
|
||
if bc := task.PrivateData.BillingContext; bc != nil {
|
||
other["model_price"] = bc.ModelPrice
|
||
other["group_ratio"] = bc.GroupRatio
|
||
if len(bc.OtherRatios) > 0 {
|
||
for k, v := range bc.OtherRatios {
|
||
other[k] = v
|
||
}
|
||
}
|
||
}
|
||
props := task.Properties
|
||
if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName {
|
||
other["is_model_mapped"] = true
|
||
other["upstream_model_name"] = props.UpstreamModelName
|
||
}
|
||
return other
|
||
}
|
||
|
||
// taskModelName 从 BillingContext 或 Properties 中获取模型名称。
|
||
func taskModelName(task *model.Task) string {
|
||
if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" {
|
||
return bc.OriginModelName
|
||
}
|
||
return task.Properties.OriginModelName
|
||
}
|
||
|
||
// RefundTaskQuota 统一的任务失败退款逻辑。
|
||
// 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。
|
||
func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
|
||
quota := task.Quota
|
||
if quota == 0 {
|
||
return
|
||
}
|
||
|
||
// 1. 退还资金来源(钱包或订阅)
|
||
if err := taskAdjustFunding(task, -quota); err != nil {
|
||
logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
|
||
return
|
||
}
|
||
|
||
// 2. 退还令牌额度
|
||
taskAdjustTokenQuota(ctx, task, -quota)
|
||
|
||
// 3. 记录日志
|
||
other := taskBillingOther(task)
|
||
other["task_id"] = task.TaskID
|
||
other["reason"] = reason
|
||
model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
|
||
UserId: task.UserId,
|
||
LogType: model.LogTypeRefund,
|
||
Content: "",
|
||
ChannelId: task.ChannelId,
|
||
ModelName: taskModelName(task),
|
||
Quota: quota,
|
||
TokenId: task.PrivateData.TokenId,
|
||
Group: task.Group,
|
||
Other: other,
|
||
})
|
||
}
|
||
|
||
// RecalculateTaskQuota 通用的异步差额结算。
|
||
// actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。
|
||
// reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。
|
||
func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
|
||
if actualQuota <= 0 {
|
||
return
|
||
}
|
||
preConsumedQuota := task.Quota
|
||
quotaDelta := actualQuota - preConsumedQuota
|
||
|
||
if quotaDelta == 0 {
|
||
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)",
|
||
task.TaskID, logger.LogQuota(actualQuota), reason))
|
||
return
|
||
}
|
||
|
||
logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)",
|
||
task.TaskID,
|
||
logger.LogQuota(quotaDelta),
|
||
logger.LogQuota(actualQuota),
|
||
logger.LogQuota(preConsumedQuota),
|
||
reason,
|
||
))
|
||
|
||
// 调整资金来源
|
||
if err := taskAdjustFunding(task, quotaDelta); err != nil {
|
||
logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
|
||
return
|
||
}
|
||
|
||
// 调整令牌额度
|
||
taskAdjustTokenQuota(ctx, task, quotaDelta)
|
||
|
||
task.Quota = actualQuota
|
||
|
||
var logType int
|
||
var logQuota int
|
||
if quotaDelta > 0 {
|
||
logType = model.LogTypeConsume
|
||
logQuota = quotaDelta
|
||
model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
|
||
model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
|
||
} else {
|
||
logType = model.LogTypeRefund
|
||
logQuota = -quotaDelta
|
||
}
|
||
other := taskBillingOther(task)
|
||
other["task_id"] = task.TaskID
|
||
other["reason"] = reason
|
||
other["pre_consumed_quota"] = preConsumedQuota
|
||
other["actual_quota"] = actualQuota
|
||
model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
|
||
UserId: task.UserId,
|
||
LogType: logType,
|
||
Content: "",
|
||
ChannelId: task.ChannelId,
|
||
ModelName: taskModelName(task),
|
||
Quota: logQuota,
|
||
TokenId: task.PrivateData.TokenId,
|
||
Group: task.Group,
|
||
Other: other,
|
||
})
|
||
}
|
||
|
||
// RecalculateTaskQuotaByTokens 根据实际 token 消耗重新计费(异步差额结算)。
|
||
// 当任务成功且返回了 totalTokens 时,根据模型倍率和分组倍率重新计算实际扣费额度,
|
||
// 与预扣费的差额进行补扣或退还。支持钱包和订阅计费来源。
|
||
func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
|
||
if totalTokens <= 0 {
|
||
return
|
||
}
|
||
|
||
modelName := taskModelName(task)
|
||
|
||
// 获取模型价格和倍率
|
||
modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
|
||
// 只有配置了倍率(非固定价格)时才按 token 重新计费
|
||
if !hasRatioSetting || modelRatio <= 0 {
|
||
return
|
||
}
|
||
|
||
// 获取用户和组的倍率信息
|
||
group := task.Group
|
||
if group == "" {
|
||
user, err := model.GetUserById(task.UserId, false)
|
||
if err == nil {
|
||
group = user.Group
|
||
}
|
||
}
|
||
if group == "" {
|
||
return
|
||
}
|
||
|
||
groupRatio := ratio_setting.GetGroupRatio(group)
|
||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
|
||
|
||
var finalGroupRatio float64
|
||
if hasUserGroupRatio {
|
||
finalGroupRatio = userGroupRatio
|
||
} else {
|
||
finalGroupRatio = groupRatio
|
||
}
|
||
|
||
// 计算实际应扣费额度: totalTokens * modelRatio * groupRatio
|
||
actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
|
||
|
||
reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
|
||
RecalculateTaskQuota(ctx, task, actualQuota, reason)
|
||
}
|