Files
new-api/service/task_billing.go

286 lines
9.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
	"github.com/gin-gonic/gin"
)
// LogTaskConsumption writes the consumption log entry for a task request and
// updates the user and channel usage counters. It performs no charging itself;
// according to the original note, the actual deduction has already been done
// by the billing session (pre-consume + settle), which is not visible here.
//
// The log content starts with "操作 <action>". When the origin model is listed
// in constant.TaskPricePatches a per-call billing marker is appended;
// otherwise every entry of PriceData.OtherRatios whose value differs from 1.0
// is appended. Ratios are emitted in sorted key order so that the same request
// always yields the same log line.
func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) {
	tokenName := c.GetString("token_name")
	logContent := fmt.Sprintf("操作 %s", info.Action)
	if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) {
		// Per-call pricing: there are no ratio details to report.
		logContent = fmt.Sprintf("%s按次计费", logContent)
	} else if ratios := info.PriceData.OtherRatios; len(ratios) > 0 {
		// Go randomizes map iteration order, so collect the relevant keys and
		// sort them before formatting to keep the log content deterministic.
		keys := make([]string, 0, len(ratios))
		for key, ratio := range ratios {
			if ratio != 1.0 {
				keys = append(keys, key)
			}
		}
		if len(keys) > 0 {
			sort.Strings(keys)
			contents := make([]string, 0, len(keys))
			for _, key := range keys {
				contents = append(contents, fmt.Sprintf("%s: %.2f", key, ratios[key]))
			}
			logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
		}
	}

	// Extra structured fields stored alongside the log entry.
	other := make(map[string]interface{})
	other["request_path"] = c.Request.URL.Path
	other["model_price"] = info.PriceData.ModelPrice
	other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio
	if info.PriceData.GroupRatioInfo.HasSpecialRatio {
		other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio
	}
	if info.IsModelMapped {
		other["is_model_mapped"] = true
		other["upstream_model_name"] = info.UpstreamModelName
	}

	model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
		ChannelId: info.ChannelId,
		ModelName: info.OriginModelName,
		TokenName: tokenName,
		Quota:     info.PriceData.Quota,
		Content:   logContent,
		TokenId:   info.TokenId,
		Group:     info.UsingGroup,
		Other:     other,
	})
	model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota)
	model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota)
}
// ---------------------------------------------------------------------------
// 异步任务计费辅助函数
// ---------------------------------------------------------------------------
// resolveTokenKey looks up the token with the given tokenId at call time and
// returns its Key (the original note says the key is needed for Redis cache
// operations). When the lookup returns an error — for example because the
// token has been deleted — a warning mentioning tokenId and taskID is logged
// and the empty string is returned.
func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string {
	token, lookupErr := model.GetTokenById(tokenId)
	if lookupErr == nil {
		return token.Key
	}
	warning := fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, lookupErr.Error())
	logger.LogWarn(ctx, warning)
	return ""
}
// taskIsSubscription reports whether the task is billed against a
// subscription: its billing source must be BillingSourceSubscription and it
// must carry a positive subscription ID.
func taskIsSubscription(task *model.Task) bool {
	if task.PrivateData.BillingSource != BillingSourceSubscription {
		return false
	}
	return task.PrivateData.SubscriptionId > 0
}
// taskAdjustFunding applies delta to the task's funding source, which is
// either the user's wallet or a subscription. A positive delta charges, a
// negative delta refunds.
//
// Subscription-billed tasks forward delta unchanged to the subscription.
// Wallet-billed tasks decrease the user quota for a positive delta and
// increase it by -delta otherwise. The error of the underlying model call is
// returned as-is.
func taskAdjustFunding(task *model.Task, delta int) error {
	switch {
	case taskIsSubscription(task):
		return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta))
	case delta > 0:
		return model.DecreaseUserQuota(task.UserId, delta)
	default:
		return model.IncreaseUserQuota(task.UserId, -delta, false)
	}
}
// taskAdjustTokenQuota applies delta to the quota of the token the task was
// submitted with. A positive delta charges, a negative delta refunds.
//
// The token key is resolved at call time through resolveTokenKey rather than
// read from PrivateData. Nothing happens when the task has no token ID, when
// delta is zero, or when the key cannot be resolved. A failed adjustment is
// logged as a warning and is not reported to the caller.
func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) {
	tokenID := task.PrivateData.TokenId
	if delta == 0 || tokenID <= 0 {
		return
	}

	key := resolveTokenKey(ctx, tokenID, task.TaskID)
	if key == "" {
		return
	}

	var adjustErr error
	switch {
	case delta > 0:
		adjustErr = model.DecreaseTokenQuota(tokenID, key, delta)
	default:
		adjustErr = model.IncreaseTokenQuota(tokenID, key, -delta)
	}
	if adjustErr == nil {
		return
	}
	logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, adjustErr.Error()))
}
// taskBillingOther builds the "Other" field map for a task billing log entry.
//
// When the task has a BillingContext, its model price, group ratio and every
// entry of OtherRatios are copied into the map (an OtherRatios key equal to
// "model_price" or "group_ratio" would overwrite those entries). When the
// task's upstream model name is set and differs from the origin model name,
// the mapping is recorded as well.
func taskBillingOther(task *model.Task) map[string]interface{} {
	fields := map[string]interface{}{}

	if billing := task.PrivateData.BillingContext; billing != nil {
		fields["model_price"] = billing.ModelPrice
		fields["group_ratio"] = billing.GroupRatio
		// Ranging over an empty or nil map is a no-op, so no length check is needed.
		for name, ratio := range billing.OtherRatios {
			fields[name] = ratio
		}
	}

	upstream := task.Properties.UpstreamModelName
	if upstream != "" && upstream != task.Properties.OriginModelName {
		fields["is_model_mapped"] = true
		fields["upstream_model_name"] = upstream
	}
	return fields
}
// taskModelName returns the model name to use for billing a task: the
// BillingContext's origin model name when a context exists and the name is
// non-empty, otherwise the origin model name from the task's Properties.
func taskModelName(task *model.Task) string {
	billing := task.PrivateData.BillingContext
	if billing == nil || billing.OriginModelName == "" {
		return task.Properties.OriginModelName
	}
	return billing.OriginModelName
}
// RefundTaskQuota is the shared refund path for failed asynchronous tasks.
// It returns the task's pre-consumed quota (task.Quota) to the funding source
// (wallet or subscription), returns the same amount to the token quota, and
// writes a refund log entry that carries the task ID and the given reason.
//
// Nothing is done when task.Quota is zero. If refunding the funding source
// fails, a warning is logged and neither the token refund nor the log entry
// is performed.
func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) {
	refund := task.Quota
	if refund == 0 {
		return
	}

	// Step 1: give the quota back to the wallet or subscription.
	err := taskAdjustFunding(task, -refund)
	if err != nil {
		logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error()))
		return
	}

	// Step 2: give the quota back to the token.
	taskAdjustTokenQuota(ctx, task, -refund)

	// Step 3: record the refund in the billing log.
	details := taskBillingOther(task)
	details["task_id"] = task.TaskID
	details["reason"] = reason

	entry := model.RecordTaskBillingLogParams{
		UserId:    task.UserId,
		LogType:   model.LogTypeRefund,
		Content:   "",
		ChannelId: task.ChannelId,
		ModelName: taskModelName(task),
		Quota:     refund,
		TokenId:   task.PrivateData.TokenId,
		Group:     task.Group,
		Other:     details,
	}
	model.RecordTaskBillingLog(entry)
}
// RecalculateTaskQuota settles the difference between the quota that was
// pre-consumed for an asynchronous task (task.Quota) and the quota it should
// actually cost (actualQuota).
//
// Parameters:
//   - actualQuota: the final amount to charge; values <= 0 are ignored.
//   - reason: free-form text describing why the settlement happens (for
//     example a token-based recalculation or an adaptor adjustment); it is
//     written to the info log and stored as the billing log Content.
//
// When the amounts match, only an info line is logged. Otherwise the delta is
// applied to the funding source first; if that fails an error is logged and
// nothing else changes. On success the token quota is adjusted, task.Quota is
// set to actualQuota on the in-memory task (this function does not persist
// it), and a billing log entry is written: a consume entry for an extra
// charge (which also bumps the user and channel usage counters) or a refund
// entry for a partial refund.
func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) {
	if actualQuota <= 0 {
		return
	}
	preConsumedQuota := task.Quota
	quotaDelta := actualQuota - preConsumedQuota
	if quotaDelta == 0 {
		logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确 (%s, %s)",
			task.TaskID, logger.LogQuota(actualQuota), reason))
		return
	}
	logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算: delta=%s (实际: %s, 预扣: %s, %s)",
		task.TaskID,
		logger.LogQuota(quotaDelta),
		logger.LogQuota(actualQuota),
		logger.LogQuota(preConsumedQuota),
		reason,
	))

	// Adjust the funding source first; abort without touching anything else
	// when it fails so the task keeps its pre-consumed quota.
	if err := taskAdjustFunding(task, quotaDelta); err != nil {
		logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error()))
		return
	}
	// Token quota adjustment is best-effort (failures are only logged there).
	taskAdjustTokenQuota(ctx, task, quotaDelta)
	task.Quota = actualQuota

	var logType int
	var logQuota int
	if quotaDelta > 0 {
		// Extra charge: counts towards user and channel usage statistics.
		logType = model.LogTypeConsume
		logQuota = quotaDelta
		model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta)
		model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta)
	} else {
		// Partial refund: log the refunded amount as a positive number.
		logType = model.LogTypeRefund
		logQuota = -quotaDelta
	}

	// The reason is carried in Content rather than in the Other map.
	other := taskBillingOther(task)
	other["task_id"] = task.TaskID
	other["pre_consumed_quota"] = preConsumedQuota
	other["actual_quota"] = actualQuota
	model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{
		UserId:    task.UserId,
		LogType:   logType,
		Content:   reason,
		ChannelId: task.ChannelId,
		ModelName: taskModelName(task),
		Quota:     logQuota,
		TokenId:   task.PrivateData.TokenId,
		Group:     task.Group,
		Other:     other,
	})
}
// RecalculateTaskQuotaByTokens re-prices a task from its actual token usage
// and settles the difference against the pre-consumed quota through
// RecalculateTaskQuota (which handles both wallet and subscription funding).
//
// The actual quota is totalTokens * modelRatio * groupRatio, truncated to an
// int. Nothing is done when:
//   - totalTokens is not positive;
//   - the model has no ratio setting or a non-positive ratio (the original
//     note says such models are fixed-price and are not re-priced by tokens);
//   - the group cannot be determined (task.Group is empty and the user lookup
//     fails or yields an empty group). A failed lookup is logged as a warning.
func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) {
	if totalTokens <= 0 {
		return
	}
	modelName := taskModelName(task)

	// Only ratio-priced models are re-priced by token count.
	modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName)
	if !hasRatioSetting || modelRatio <= 0 {
		return
	}

	// Prefer the group stored on the task; fall back to the user's group.
	group := task.Group
	if group == "" {
		user, err := model.GetUserById(task.UserId, false)
		if err != nil {
			// Without a group no ratio can be chosen; make the skipped
			// settlement visible instead of dropping the error silently.
			logger.LogWarn(ctx, fmt.Sprintf("获取用户分组失败 (userId=%v, task=%s): %s", task.UserId, task.TaskID, err.Error()))
			return
		}
		group = user.Group
	}
	if group == "" {
		return
	}

	groupRatio := ratio_setting.GetGroupRatio(group)
	// NOTE(review): the same group is passed for both arguments; presumably
	// these are the user's group and the group in use — confirm that a
	// group-specific override is really meant to be looked up this way.
	userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group)
	finalGroupRatio := groupRatio
	if hasUserGroupRatio {
		finalGroupRatio = userGroupRatio
	}

	// The float-to-int conversion truncates toward zero; a result of 0 makes
	// RecalculateTaskQuota return without settling.
	actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio)
	reason := fmt.Sprintf("token重算tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio)
	RecalculateTaskQuota(ctx, task, actualQuota, reason)
}