mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:40:59 +00:00
refactor: 抽象统一计费会话 BillingSession
将散落在多个文件中的预扣费/结算/退款逻辑抽象为统一的 BillingSession 生命周期管理: - 新增 BillingSettler 接口 (relay/common/billing.go) 避免循环引用 - 新增 FundingSource 接口 + WalletFunding / SubscriptionFunding 实现 (service/funding_source.go) - 新增 BillingSession 封装预扣/结算/退款原子操作 (service/billing_session.go) - 新增 SettleBilling 统一结算辅助函数,替换各 handler 中的 quotaDelta 模式 - 重写 PreConsumeBilling 为 BillingSession 工厂入口 - controller/relay.go 退款守卫改用 BillingSession.Refund() 修复的 Bug: - 令牌额度泄漏:PreConsumeTokenQuota 成功但 DecreaseUserQuota 失败时未回滚 - 订阅退款遗漏:FinalPreConsumedQuota=0 但 SubscriptionPreConsumed>0 时跳过退款 - 订阅多扣费:subConsume 强制为 1 但 FinalPreConsumedQuota 不同步 - 退款路径不统一:钱包/订阅退款逻辑现统一由 FundingSource.Refund 分派
This commit is contained in:
@@ -2,12 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -18,89 +14,61 @@ const (
|
||||
BillingSourceSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PreConsumeBilling decides whether to pre-consume from subscription or wallet based on user preference.
|
||||
// It also always pre-consumes token quota in quota units (same as legacy flow).
|
||||
// PreConsumeBilling 根据用户计费偏好创建 BillingSession 并执行预扣费。
|
||||
// 会话存储在 relayInfo.Billing 上,供后续 Settle / Refund 使用。
|
||||
func PreConsumeBilling(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
if relayInfo == nil {
|
||||
return types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference)
|
||||
trySubscription := func() *types.NewAPIError {
|
||||
quotaType := 0
|
||||
// For total quota: consume preConsumedQuota quota units.
|
||||
subConsume := int64(preConsumedQuota)
|
||||
if subConsume <= 0 {
|
||||
subConsume = 1
|
||||
}
|
||||
|
||||
// Pre-consume token quota in quota units to keep token limits consistent.
|
||||
if preConsumedQuota > 0 {
|
||||
if err := PreConsumeTokenQuota(relayInfo, preConsumedQuota); err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
}
|
||||
|
||||
res, err := model.PreConsumeUserSubscription(relayInfo.RequestId, relayInfo.UserId, relayInfo.OriginModelName, quotaType, subConsume)
|
||||
if err != nil {
|
||||
// revert token pre-consume when subscription fails
|
||||
if preConsumedQuota > 0 && !relayInfo.IsPlayground {
|
||||
_ = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, preConsumedQuota)
|
||||
}
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("订阅预扣失败: %s", errMsg), types.ErrorCodeQueryDataError, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
relayInfo.BillingSource = BillingSourceSubscription
|
||||
relayInfo.SubscriptionId = res.UserSubscriptionId
|
||||
relayInfo.SubscriptionPreConsumed = res.PreConsumed
|
||||
relayInfo.SubscriptionPostDelta = 0
|
||||
relayInfo.SubscriptionAmountTotal = res.AmountTotal
|
||||
relayInfo.SubscriptionAmountUsedAfterPreConsume = res.AmountUsedAfter
|
||||
if planInfo, err := model.GetSubscriptionPlanInfoByUserSubscriptionId(res.UserSubscriptionId); err == nil && planInfo != nil {
|
||||
relayInfo.SubscriptionPlanId = planInfo.PlanId
|
||||
relayInfo.SubscriptionPlanTitle = planInfo.PlanTitle
|
||||
}
|
||||
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
||||
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 使用订阅计费预扣:订阅=%d,token_quota=%d", relayInfo.UserId, res.PreConsumed, preConsumedQuota))
|
||||
return nil
|
||||
}
|
||||
|
||||
tryWallet := func() *types.NewAPIError {
|
||||
relayInfo.BillingSource = BillingSourceWallet
|
||||
relayInfo.SubscriptionId = 0
|
||||
relayInfo.SubscriptionPreConsumed = 0
|
||||
return PreConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
}
|
||||
|
||||
switch pref {
|
||||
case "subscription_only":
|
||||
return trySubscription()
|
||||
case "wallet_only":
|
||||
return tryWallet()
|
||||
case "wallet_first":
|
||||
if err := tryWallet(); err != nil {
|
||||
// only fallback for insufficient wallet quota
|
||||
if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
|
||||
return trySubscription()
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
case "subscription_first":
|
||||
fallthrough
|
||||
default:
|
||||
if err := trySubscription(); err != nil {
|
||||
// fallback only when subscription not available/insufficient
|
||||
if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
|
||||
return tryWallet()
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
session, apiErr := NewBillingSession(c, relayInfo, preConsumedQuota)
|
||||
if apiErr != nil {
|
||||
return apiErr
|
||||
}
|
||||
relayInfo.Billing = session
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SettleBilling — 后结算辅助函数
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SettleBilling 执行计费结算。如果 RelayInfo 上有 BillingSession 则通过 session 结算,
|
||||
// 否则回退到旧的 PostConsumeQuota 路径(兼容按次计费等场景)。
|
||||
func SettleBilling(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, actualQuota int) error {
|
||||
if relayInfo.Billing != nil {
|
||||
preConsumed := relayInfo.Billing.GetPreConsumedQuota()
|
||||
delta := actualQuota - preConsumed
|
||||
|
||||
if delta > 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(delta),
|
||||
logger.FormatQuota(actualQuota),
|
||||
logger.FormatQuota(preConsumed),
|
||||
))
|
||||
} else if delta < 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(-delta),
|
||||
logger.FormatQuota(actualQuota),
|
||||
logger.FormatQuota(preConsumed),
|
||||
))
|
||||
} else {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费与实际消耗一致,无需调整:%s(按次计费)",
|
||||
logger.FormatQuota(actualQuota),
|
||||
))
|
||||
}
|
||||
|
||||
if err := relayInfo.Billing.Settle(actualQuota); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送额度通知
|
||||
if actualQuota != 0 {
|
||||
checkAndSendQuotaNotify(relayInfo, actualQuota-preConsumed, preConsumed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 回退:无 BillingSession 时使用旧路径
|
||||
quotaDelta := actualQuota - relayInfo.FinalPreConsumedQuota
|
||||
if quotaDelta != 0 {
|
||||
return PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
321
service/billing_session.go
Normal file
321
service/billing_session.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BillingSession — 统一计费会话
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// BillingSession wraps the pre-consume / settle / refund lifecycle of a single request.
|
||||
// It implements the relaycommon.BillingSettler interface.
|
||||
type BillingSession struct {
|
||||
relayInfo *relaycommon.RelayInfo
|
||||
funding FundingSource
|
||||
preConsumedQuota int // quota actually pre-consumed (may be 0 for trusted users)
|
||||
tokenConsumed int // amount actually deducted from the token quota
|
||||
settled bool // set once Settle has completed
|
||||
refunded bool // set once Refund has started refunding
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Settle 根据实际消耗额度进行结算。
|
||||
func (s *BillingSession) Settle(actualQuota int) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.settled {
|
||||
return nil
|
||||
}
|
||||
delta := actualQuota - s.preConsumedQuota
|
||||
if delta == 0 {
|
||||
s.settled = true
|
||||
return nil
|
||||
}
|
||||
// 1) 调整资金来源
|
||||
if err := s.funding.Settle(delta); err != nil {
|
||||
return err
|
||||
}
|
||||
// 2) 调整令牌额度
|
||||
if !s.relayInfo.IsPlayground {
|
||||
if delta > 0 {
|
||||
if err := model.DecreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, delta); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, -delta); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
// 3) 更新 relayInfo 上的订阅 PostDelta(用于日志)
|
||||
if s.funding.Source() == BillingSourceSubscription {
|
||||
s.relayInfo.SubscriptionPostDelta += int64(delta)
|
||||
}
|
||||
s.settled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refund 退还所有预扣费,幂等安全,异步执行。
|
||||
func (s *BillingSession) Refund(c *gin.Context) {
|
||||
s.mu.Lock()
|
||||
if s.settled || s.refunded || !s.needsRefundLocked() {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.refunded = true
|
||||
s.mu.Unlock()
|
||||
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, funding=%s)",
|
||||
s.relayInfo.UserId,
|
||||
logger.FormatQuota(s.tokenConsumed),
|
||||
s.funding.Source(),
|
||||
))
|
||||
|
||||
// 复制需要的值到闭包中
|
||||
tokenId := s.relayInfo.TokenId
|
||||
tokenKey := s.relayInfo.TokenKey
|
||||
isPlayground := s.relayInfo.IsPlayground
|
||||
tokenConsumed := s.tokenConsumed
|
||||
funding := s.funding
|
||||
|
||||
gopool.Go(func() {
|
||||
// 1) 退还资金来源
|
||||
if err := funding.Refund(); err != nil {
|
||||
common.SysLog("error refunding billing source: " + err.Error())
|
||||
}
|
||||
// 2) 退还令牌额度
|
||||
if tokenConsumed > 0 && !isPlayground {
|
||||
if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil {
|
||||
common.SysLog("error refunding token quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// NeedsRefund reports whether there is pre-consumed state that still has to be refunded.
// It takes s.mu and delegates the decision to needsRefundLocked.
|
||||
func (s *BillingSession) NeedsRefund() bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.needsRefundLocked()
|
||||
}
|
||||
|
||||
func (s *BillingSession) needsRefundLocked() bool {
|
||||
if s.settled || s.refunded {
|
||||
return false
|
||||
}
|
||||
if s.tokenConsumed > 0 {
|
||||
return true
|
||||
}
|
||||
// 订阅可能在 tokenConsumed=0 时仍预扣了额度
|
||||
if sub, ok := s.funding.(*SubscriptionFunding); ok && sub.preConsumed > 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreConsumedQuota returns the quota that was actually pre-consumed.
// NOTE(review): the field is read without taking s.mu; it is only assigned in
// preConsume in the code shown here — confirm no concurrent writer exists.
|
||||
func (s *BillingSession) GetPreConsumedQuota() int {
|
||||
return s.preConsumedQuota
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// PreConsume — 统一预扣费入口(含信任额度旁路)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// preConsume 执行预扣费:信任检查 -> 令牌预扣 -> 资金来源预扣。
|
||||
// 任一步骤失败时原子回滚已完成的步骤。
|
||||
func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIError {
|
||||
effectiveQuota := quota
|
||||
|
||||
// ---- 信任额度旁路 ----
|
||||
if s.shouldTrust(c) {
|
||||
effectiveQuota = 0
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source()))
|
||||
} else if effectiveQuota > 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 需要预扣费 %s (funding=%s)", s.relayInfo.UserId, logger.FormatQuota(effectiveQuota), s.funding.Source()))
|
||||
}
|
||||
|
||||
// ---- 1) 预扣令牌额度 ----
|
||||
if effectiveQuota > 0 {
|
||||
if err := PreConsumeTokenQuota(s.relayInfo, effectiveQuota); err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
s.tokenConsumed = effectiveQuota
|
||||
}
|
||||
|
||||
// ---- 2) 预扣资金来源 ----
|
||||
if err := s.funding.PreConsume(effectiveQuota); err != nil {
|
||||
// 回滚令牌额度
|
||||
if s.tokenConsumed > 0 && !s.relayInfo.IsPlayground {
|
||||
_ = model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, s.tokenConsumed)
|
||||
s.tokenConsumed = 0
|
||||
}
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if strings.Contains(errMsg, "用户额度不足") || strings.Contains(errMsg, "预扣费额度失败") {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
s.preConsumedQuota = effectiveQuota
|
||||
|
||||
// ---- 同步 RelayInfo 兼容字段 ----
|
||||
s.syncRelayInfo()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldTrust 统一信任额度检查,适用于钱包和订阅。
|
||||
func (s *BillingSession) shouldTrust(c *gin.Context) bool {
|
||||
trustQuota := common.GetTrustQuota()
|
||||
if trustQuota <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查令牌是否充足
|
||||
tokenTrusted := s.relayInfo.TokenUnlimited
|
||||
if !tokenTrusted {
|
||||
tokenQuota := c.GetInt("token_quota")
|
||||
tokenTrusted = tokenQuota > trustQuota
|
||||
}
|
||||
if !tokenTrusted {
|
||||
return false
|
||||
}
|
||||
|
||||
switch s.funding.Source() {
|
||||
case BillingSourceWallet:
|
||||
return s.relayInfo.UserQuota > trustQuota
|
||||
case BillingSourceSubscription:
|
||||
// 订阅暂不支持信任旁路(订阅剩余额度需要额外查询,且预扣开销小)
|
||||
// 后续可以在此处添加订阅信任逻辑
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// syncRelayInfo 将 BillingSession 的状态同步到 RelayInfo 的兼容字段上。
|
||||
func (s *BillingSession) syncRelayInfo() {
|
||||
info := s.relayInfo
|
||||
info.FinalPreConsumedQuota = s.preConsumedQuota
|
||||
info.BillingSource = s.funding.Source()
|
||||
|
||||
if sub, ok := s.funding.(*SubscriptionFunding); ok {
|
||||
info.SubscriptionId = sub.subscriptionId
|
||||
info.SubscriptionPreConsumed = sub.preConsumed
|
||||
info.SubscriptionPostDelta = 0
|
||||
info.SubscriptionAmountTotal = sub.AmountTotal
|
||||
info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter
|
||||
info.SubscriptionPlanId = sub.PlanId
|
||||
info.SubscriptionPlanTitle = sub.PlanTitle
|
||||
} else {
|
||||
info.SubscriptionId = 0
|
||||
info.SubscriptionPreConsumed = 0
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// NewBillingSession 工厂 — 根据计费偏好创建会话并处理回退
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// NewBillingSession 根据用户计费偏好创建 BillingSession,处理 subscription_first / wallet_first 的回退。
|
||||
func NewBillingSession(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) (*BillingSession, *types.NewAPIError) {
|
||||
if relayInfo == nil {
|
||||
return nil, types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference)
|
||||
|
||||
// 钱包路径需要先检查用户额度
|
||||
tryWallet := func() (*BillingSession, *types.NewAPIError) {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)),
|
||||
types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
|
||||
types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return nil, types.NewErrorWithStatusCode(
|
||||
fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)),
|
||||
types.ErrorCodeInsufficientUserQuota, http.StatusForbidden,
|
||||
types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
relayInfo.UserQuota = userQuota
|
||||
|
||||
session := &BillingSession{
|
||||
relayInfo: relayInfo,
|
||||
funding: &WalletFunding{userId: relayInfo.UserId},
|
||||
}
|
||||
if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil {
|
||||
return nil, apiErr
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
trySubscription := func() (*BillingSession, *types.NewAPIError) {
|
||||
subConsume := int64(preConsumedQuota)
|
||||
if subConsume <= 0 {
|
||||
subConsume = 1
|
||||
}
|
||||
session := &BillingSession{
|
||||
relayInfo: relayInfo,
|
||||
funding: &SubscriptionFunding{
|
||||
requestId: relayInfo.RequestId,
|
||||
userId: relayInfo.UserId,
|
||||
modelName: relayInfo.OriginModelName,
|
||||
amount: subConsume,
|
||||
},
|
||||
}
|
||||
if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil {
|
||||
return nil, apiErr
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
switch pref {
|
||||
case "subscription_only":
|
||||
return trySubscription()
|
||||
case "wallet_only":
|
||||
return tryWallet()
|
||||
case "wallet_first":
|
||||
session, err := tryWallet()
|
||||
if err != nil {
|
||||
if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
|
||||
return trySubscription()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
case "subscription_first":
|
||||
fallthrough
|
||||
default:
|
||||
session, err := trySubscription()
|
||||
if err != nil {
|
||||
if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota {
|
||||
return tryWallet()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
}
|
||||
137
service/funding_source.go
Normal file
137
service/funding_source.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// FundingSource — 资金来源接口(钱包 or 订阅)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// FundingSource abstracts where pre-consumed quota is taken from.
|
||||
type FundingSource interface {
|
||||
// Source returns the funding source identifier: "wallet" or "subscription".
|
||||
Source() string
|
||||
// PreConsume pre-consumes amount quota from this funding source.
|
||||
PreConsume(amount int) error
|
||||
// Settle adjusts the funding source by delta (positive charges more, negative returns quota).
|
||||
Settle(delta int) error
|
||||
// Refund returns everything that was pre-consumed.
|
||||
Refund() error
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WalletFunding — 钱包资金来源实现
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// WalletFunding is the FundingSource implementation that draws on the user's quota.
type WalletFunding struct {
|
||||
userId int
|
||||
consumed int // user quota actually pre-consumed
|
||||
}
|
||||

||||
// Source identifies this funding source as the wallet.
func (w *WalletFunding) Source() string { return BillingSourceWallet }
|
||||
|
||||
func (w *WalletFunding) PreConsume(amount int) error {
|
||||
if amount <= 0 {
|
||||
return nil
|
||||
}
|
||||
if err := model.DecreaseUserQuota(w.userId, amount); err != nil {
|
||||
return err
|
||||
}
|
||||
w.consumed = amount
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WalletFunding) Settle(delta int) error {
|
||||
if delta == 0 {
|
||||
return nil
|
||||
}
|
||||
if delta > 0 {
|
||||
return model.DecreaseUserQuota(w.userId, delta)
|
||||
}
|
||||
return model.IncreaseUserQuota(w.userId, -delta, false)
|
||||
}
|
||||
|
||||
func (w *WalletFunding) Refund() error {
|
||||
if w.consumed <= 0 {
|
||||
return nil
|
||||
}
|
||||
return model.IncreaseUserQuota(w.userId, w.consumed, false)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SubscriptionFunding — 订阅资金来源实现
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubscriptionFunding is the FundingSource implementation that draws on a user subscription.
type SubscriptionFunding struct {
|
||||
requestId string
|
||||
userId int
|
||||
modelName string
|
||||
amount int64 // subscription quota to pre-consume (subConsume)
|
||||
subscriptionId int
|
||||
preConsumed int64
|
||||
// The fields below are filled in after a successful PreConsume and are used to sync RelayInfo.
|
||||
AmountTotal int64
|
||||
AmountUsedAfter int64
|
||||
PlanId int
|
||||
PlanTitle string
|
||||
}
|
||||

||||
// Source identifies this funding source as the subscription.
func (s *SubscriptionFunding) Source() string { return BillingSourceSubscription }
|
||||
|
||||
func (s *SubscriptionFunding) PreConsume(_ int) error {
|
||||
// amount 参数被忽略,使用内部 s.amount(已在构造时根据 preConsumedQuota 计算)
|
||||
res, err := model.PreConsumeUserSubscription(s.requestId, s.userId, s.modelName, 0, s.amount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.subscriptionId = res.UserSubscriptionId
|
||||
s.preConsumed = res.PreConsumed
|
||||
s.AmountTotal = res.AmountTotal
|
||||
s.AmountUsedAfter = res.AmountUsedAfter
|
||||
// 获取订阅计划信息
|
||||
if planInfo, err := model.GetSubscriptionPlanInfoByUserSubscriptionId(res.UserSubscriptionId); err == nil && planInfo != nil {
|
||||
s.PlanId = planInfo.PlanId
|
||||
s.PlanTitle = planInfo.PlanTitle
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubscriptionFunding) Settle(delta int) error {
|
||||
if delta == 0 {
|
||||
return nil
|
||||
}
|
||||
return model.PostConsumeUserSubscriptionDelta(s.subscriptionId, int64(delta))
|
||||
}
|
||||
|
||||
func (s *SubscriptionFunding) Refund() error {
|
||||
if s.preConsumed <= 0 {
|
||||
return nil
|
||||
}
|
||||
return refundWithRetry(func() error {
|
||||
return model.RefundSubscriptionPreConsume(s.requestId)
|
||||
})
|
||||
}
|
||||
|
||||
// refundWithRetry runs fn up to three times to raise the chance that a refund
// goes through, sleeping 200ms after the first failure and 400ms after the
// second. It returns nil as soon as one attempt succeeds, otherwise the error
// of the last attempt. A nil fn is a no-op.
//
// Only use this for refund functions that are transaction based!
func refundWithRetry(fn func() error) error {
	if fn == nil {
		return nil
	}
	const maxAttempts = 3
	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		lastErr = fn()
		if lastErr == nil {
			return nil
		}
		// Back off before the next attempt; no wait after the final one.
		if attempt < maxAttempts {
			time.Sleep(time.Duration(200*attempt) * time.Millisecond)
		}
	}
	return lastErr
}
|
||||
@@ -1,124 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
// Always refund subscription pre-consumed (can be non-zero even when FinalPreConsumedQuota is 0)
|
||||
needRefundSub := relayInfo.BillingSource == BillingSourceSubscription && relayInfo.SubscriptionId != 0 && relayInfo.SubscriptionPreConsumed > 0
|
||||
needRefundToken := relayInfo.FinalPreConsumedQuota != 0
|
||||
if !needRefundSub && !needRefundToken {
|
||||
return
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, subscription=%d)",
|
||||
relayInfo.UserId,
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
relayInfo.SubscriptionPreConsumed,
|
||||
))
|
||||
gopool.Go(func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
if relayInfoCopy.BillingSource == BillingSourceSubscription {
|
||||
if needRefundSub {
|
||||
if err := refundWithRetry(func() error {
|
||||
return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId)
|
||||
}); err != nil {
|
||||
common.SysLog("error refund subscription pre-consume: " + err.Error())
|
||||
}
|
||||
}
|
||||
// refund token quota only
|
||||
if needRefundToken && !relayInfoCopy.IsPlayground {
|
||||
_ = model.IncreaseTokenQuota(relayInfoCopy.TokenId, relayInfoCopy.TokenKey, relayInfoCopy.FinalPreConsumedQuota)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// wallet refund uses existing path (user quota + token quota)
|
||||
if needRefundToken {
|
||||
err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// refundWithRetry calls fn up to three times and stops at the first success.
// It waits 200ms before the second attempt and 400ms before the third, and
// returns the error of the last attempt when all of them fail. A nil fn is a
// no-op.
func refundWithRetry(fn func() error) error {
	if fn == nil {
		return nil
	}
	const maxAttempts = 3
	err := fn()
	for retry := 1; err != nil && retry < maxAttempts; retry++ {
		time.Sleep(time.Duration(200*retry) * time.Millisecond)
		err = fn()
	}
	return err
}
|
||||
|
||||
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
||||
// It returns the pre-consumed quota if successful, or an error if not.
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
|
||||
trustQuota := common.GetTrustQuota()
|
||||
|
||||
relayInfo.UserQuota = userQuota
|
||||
if userQuota > trustQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !relayInfo.TokenUnlimited {
|
||||
// 非无限令牌,判断令牌额度是否充足
|
||||
tokenQuota := c.GetInt("token_quota")
|
||||
if tokenQuota > trustQuota {
|
||||
// 令牌额度充足,信任令牌
|
||||
preConsumedQuota = 0
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
|
||||
}
|
||||
} else {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
if preConsumedQuota > 0 {
|
||||
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
|
||||
}
|
||||
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
||||
return nil
|
||||
}
|
||||
@@ -307,27 +307,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
} else if quotaDelta < 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(-quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
}
|
||||
|
||||
if quotaDelta != 0 {
|
||||
err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
if err := SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
@@ -432,27 +413,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, u
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
} else if quotaDelta < 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(-quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
}
|
||||
|
||||
if quotaDelta != 0 {
|
||||
err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
|
||||
if err != nil {
|
||||
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
if err := SettleBilling(ctx, relayInfo, quota); err != nil {
|
||||
logger.LogError(ctx, "error settling billing: "+err.Error())
|
||||
}
|
||||
|
||||
logModel := relayInfo.OriginModelName
|
||||
|
||||
Reference in New Issue
Block a user