diff --git a/controller/relay.go b/controller/relay.go index 5310a9fba..2d5ae7df6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -170,8 +170,8 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { // Only return quota if downstream failed and quota was actually pre-consumed if newAPIError != nil { newAPIError = service.NormalizeViolationFeeError(newAPIError) - if relayInfo.FinalPreConsumedQuota != 0 { - service.ReturnPreConsumedQuota(c, relayInfo) + if relayInfo.Billing != nil { + relayInfo.Billing.Refund(c) } service.ChargeViolationFeeIfNeeded(c, relayInfo, newAPIError) } diff --git a/relay/common/billing.go b/relay/common/billing.go new file mode 100644 index 000000000..78f5cb195 --- /dev/null +++ b/relay/common/billing.go @@ -0,0 +1,21 @@ +package common + +import "github.com/gin-gonic/gin" + +// BillingSettler 抽象计费会话的生命周期操作。 +// 由 service.BillingSession 实现,存储在 RelayInfo 上以避免循环引用。 +type BillingSettler interface { + // Settle 根据实际消耗额度进行结算,计算 delta = actualQuota - preConsumedQuota, + // 同时调整资金来源(钱包/订阅)和令牌额度。 + Settle(actualQuota int) error + + // Refund 退还所有预扣费额度(资金来源 + 令牌),幂等安全。 + // 通过 gopool 异步执行。如果已经结算或退款则不做任何操作。 + Refund(c *gin.Context) + + // NeedsRefund 返回会话是否存在需要退还的预扣状态(未结算且未退款)。 + NeedsRefund() bool + + // GetPreConsumedQuota 返回实际预扣的额度值(信任用户可能为 0)。 + GetPreConsumedQuota() int +} diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index f5c1d769e..96f68d471 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -115,6 +115,9 @@ type RelayInfo struct { SendResponseCount int ReceivedResponseCount int FinalPreConsumedQuota int // 最终预消耗的配额 + // Billing 是计费会话,封装了预扣费/结算/退款的统一生命周期。 + // 免费模型和按次计费(MJ/Task)时为 nil。 + Billing BillingSettler // BillingSource indicates whether this request is billed from wallet quota or subscription. // "" or "wallet" => wallet; "subscription" => subscription BillingSource string diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 74abfe5b2..21180d8de 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -423,29 +423,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - relayInfo.FinalPreConsumedQuota - - //logger.LogInfo(ctx, fmt.Sprintf("request quota delta: %s", logger.FormatQuota(quotaDelta))) - - if quotaDelta > 0 { - logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", - logger.FormatQuota(quotaDelta), - logger.FormatQuota(quota), - logger.FormatQuota(relayInfo.FinalPreConsumedQuota), - )) - } else if quotaDelta < 0 { - logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", - logger.FormatQuota(-quotaDelta), - logger.FormatQuota(quota), - logger.FormatQuota(relayInfo.FinalPreConsumedQuota), - )) - } - - if quotaDelta != 0 { - err := service.PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) - if err != nil { - logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } + if err := service.SettleBilling(ctx, relayInfo, quota); err != nil { + logger.LogError(ctx, "error settling billing: "+err.Error()) } logModel := modelName diff --git a/service/billing.go b/service/billing.go index c7b3c6d8a..f2351a2ec 100644 --- a/service/billing.go +++ b/service/billing.go @@ -2,12 +2,8 @@ package service import ( "fmt" - "net/http" - "strings" - "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/logger" - "github.com/QuantumNous/new-api/model" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" @@ -18,89 +14,61 @@ const ( BillingSourceSubscription = "subscription" ) -// PreConsumeBilling decides whether to pre-consume from subscription or wallet based on user preference. -// It also always pre-consumes token quota in quota units (same as legacy flow). +// PreConsumeBilling 根据用户计费偏好创建 BillingSession 并执行预扣费。 +// 会话存储在 relayInfo.Billing 上,供后续 Settle / Refund 使用。 func PreConsumeBilling(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError { - if relayInfo == nil { - return types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) - } - - pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference) - trySubscription := func() *types.NewAPIError { - quotaType := 0 - // For total quota: consume preConsumedQuota quota units. - subConsume := int64(preConsumedQuota) - if subConsume <= 0 { - subConsume = 1 - } - - // Pre-consume token quota in quota units to keep token limits consistent. - if preConsumedQuota > 0 { - if err := PreConsumeTokenQuota(relayInfo, preConsumedQuota); err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - } - - res, err := model.PreConsumeUserSubscription(relayInfo.RequestId, relayInfo.UserId, relayInfo.OriginModelName, quotaType, subConsume) - if err != nil { - // revert token pre-consume when subscription fails - if preConsumedQuota > 0 && !relayInfo.IsPlayground { - _ = model.IncreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, preConsumedQuota) - } - errMsg := err.Error() - if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") { - return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - return types.NewErrorWithStatusCode(fmt.Errorf("订阅预扣失败: %s", errMsg), types.ErrorCodeQueryDataError, http.StatusInternalServerError) - } - - relayInfo.BillingSource = BillingSourceSubscription - relayInfo.SubscriptionId = res.UserSubscriptionId - relayInfo.SubscriptionPreConsumed = res.PreConsumed - relayInfo.SubscriptionPostDelta = 0 - relayInfo.SubscriptionAmountTotal = res.AmountTotal - relayInfo.SubscriptionAmountUsedAfterPreConsume = res.AmountUsedAfter - if planInfo, err := model.GetSubscriptionPlanInfoByUserSubscriptionId(res.UserSubscriptionId); err == nil && planInfo != nil { - relayInfo.SubscriptionPlanId = planInfo.PlanId - relayInfo.SubscriptionPlanTitle = planInfo.PlanTitle - } - relayInfo.FinalPreConsumedQuota = preConsumedQuota - - logger.LogInfo(c, fmt.Sprintf("用户 %d 使用订阅计费预扣:订阅=%d,token_quota=%d", relayInfo.UserId, res.PreConsumed, preConsumedQuota)) - return nil - } - - tryWallet := func() *types.NewAPIError { - relayInfo.BillingSource = BillingSourceWallet - relayInfo.SubscriptionId = 0 - relayInfo.SubscriptionPreConsumed = 0 - return PreConsumeQuota(c, preConsumedQuota, relayInfo) - } - - switch pref { - case "subscription_only": - return trySubscription() - case "wallet_only": - return tryWallet() - case "wallet_first": - if err := tryWallet(); err != nil { - // only fallback for insufficient wallet quota - if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota { - return trySubscription() - } - return err - } - return nil - case "subscription_first": - fallthrough - default: - if err := trySubscription(); err != nil { - // fallback only when subscription not available/insufficient - if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota { - return tryWallet() - } - return err - } - return nil + session, apiErr := NewBillingSession(c, relayInfo, preConsumedQuota) + if apiErr != nil { + return apiErr } + relayInfo.Billing = session + return nil +} + +// --------------------------------------------------------------------------- +// SettleBilling — 后结算辅助函数 +// --------------------------------------------------------------------------- + +// SettleBilling 执行计费结算。如果 RelayInfo 上有 BillingSession 则通过 session 结算, +// 否则回退到旧的 PostConsumeQuota 路径(兼容按次计费等场景)。 +func SettleBilling(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, actualQuota int) error { + if relayInfo.Billing != nil { + preConsumed := relayInfo.Billing.GetPreConsumedQuota() + delta := actualQuota - preConsumed + + if delta > 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(delta), + logger.FormatQuota(actualQuota), + logger.FormatQuota(preConsumed), + )) + } else if delta < 0 { + logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", + logger.FormatQuota(-delta), + logger.FormatQuota(actualQuota), + logger.FormatQuota(preConsumed), + )) + } else { + logger.LogInfo(ctx, fmt.Sprintf("预扣费与实际消耗一致,无需调整:%s(按次计费)", + logger.FormatQuota(actualQuota), + )) + } + + if err := relayInfo.Billing.Settle(actualQuota); err != nil { + return err + } + + // 发送额度通知 + if actualQuota != 0 { + checkAndSendQuotaNotify(relayInfo, actualQuota-preConsumed, preConsumed) + } + return nil + } + + // 回退:无 BillingSession 时使用旧路径 + quotaDelta := actualQuota - relayInfo.FinalPreConsumedQuota + if quotaDelta != 0 { + return PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) + } + return nil } diff --git a/service/billing_session.go b/service/billing_session.go new file mode 100644 index 000000000..5dadf7808 --- /dev/null +++ b/service/billing_session.go @@ -0,0 +1,321 @@ +package service + +import ( + "fmt" + "net/http" + "strings" + "sync" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + relaycommon "github.com/QuantumNous/new-api/relay/common" + "github.com/QuantumNous/new-api/types" + + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" +) + +// --------------------------------------------------------------------------- +// BillingSession — 统一计费会话 +// --------------------------------------------------------------------------- + +// BillingSession 封装单次请求的预扣费/结算/退款生命周期。 +// 实现 relaycommon.BillingSettler 接口。 +type BillingSession struct { + relayInfo *relaycommon.RelayInfo + funding FundingSource + preConsumedQuota int // 实际预扣额度(信任用户可能为 0) + tokenConsumed int // 令牌额度实际扣减量 + settled bool // Settle 已调用 + refunded bool // Refund 已调用 + mu sync.Mutex +} + +// Settle 根据实际消耗额度进行结算。 +func (s *BillingSession) Settle(actualQuota int) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.settled { + return nil + } + delta := actualQuota - s.preConsumedQuota + if delta == 0 { + s.settled = true + return nil + } + // 1) 调整资金来源 + if err := s.funding.Settle(delta); err != nil { + return err + } + // 2) 调整令牌额度 + if !s.relayInfo.IsPlayground { + if delta > 0 { + if err := model.DecreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, delta); err != nil { + return err + } + } else { + if err := model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, -delta); err != nil { + return err + } + } + } + // 3) 更新 relayInfo 上的订阅 PostDelta(用于日志) + if s.funding.Source() == BillingSourceSubscription { + s.relayInfo.SubscriptionPostDelta += int64(delta) + } + s.settled = true + return nil +} + +// Refund 退还所有预扣费,幂等安全,异步执行。 +func (s *BillingSession) Refund(c *gin.Context) { + s.mu.Lock() + if s.settled || s.refunded || !s.needsRefundLocked() { + s.mu.Unlock() + return + } + s.refunded = true + s.mu.Unlock() + + logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, funding=%s)", + s.relayInfo.UserId, + logger.FormatQuota(s.tokenConsumed), + s.funding.Source(), + )) + + // 复制需要的值到闭包中 + tokenId := s.relayInfo.TokenId + tokenKey := s.relayInfo.TokenKey + isPlayground := s.relayInfo.IsPlayground + tokenConsumed := s.tokenConsumed + funding := s.funding + + gopool.Go(func() { + // 1) 退还资金来源 + if err := funding.Refund(); err != nil { + common.SysLog("error refunding billing source: " + err.Error()) + } + // 2) 退还令牌额度 + if tokenConsumed > 0 && !isPlayground { + if err := model.IncreaseTokenQuota(tokenId, tokenKey, tokenConsumed); err != nil { + common.SysLog("error refunding token quota: " + err.Error()) + } + } + }) +} + +// NeedsRefund 返回是否存在需要退还的预扣状态。 +func (s *BillingSession) NeedsRefund() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.needsRefundLocked() +} + +func (s *BillingSession) needsRefundLocked() bool { + if s.settled || s.refunded { + return false + } + if s.tokenConsumed > 0 { + return true + } + // 订阅可能在 tokenConsumed=0 时仍预扣了额度 + if sub, ok := s.funding.(*SubscriptionFunding); ok && sub.preConsumed > 0 { + return true + } + return false +} + +// GetPreConsumedQuota 返回实际预扣的额度。 +func (s *BillingSession) GetPreConsumedQuota() int { + return s.preConsumedQuota +} + +// --------------------------------------------------------------------------- +// PreConsume — 统一预扣费入口(含信任额度旁路) +// --------------------------------------------------------------------------- + +// preConsume 执行预扣费:信任检查 -> 令牌预扣 -> 资金来源预扣。 +// 任一步骤失败时原子回滚已完成的步骤。 +func (s *BillingSession) preConsume(c *gin.Context, quota int) *types.NewAPIError { + effectiveQuota := quota + + // ---- 信任额度旁路 ---- + if s.shouldTrust(c) { + effectiveQuota = 0 + logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足, 信任且不需要预扣费 (funding=%s)", s.relayInfo.UserId, s.funding.Source())) + } else if effectiveQuota > 0 { + logger.LogInfo(c, fmt.Sprintf("用户 %d 需要预扣费 %s (funding=%s)", s.relayInfo.UserId, logger.FormatQuota(effectiveQuota), s.funding.Source())) + } + + // ---- 1) 预扣令牌额度 ---- + if effectiveQuota > 0 { + if err := PreConsumeTokenQuota(s.relayInfo, effectiveQuota); err != nil { + return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + s.tokenConsumed = effectiveQuota + } + + // ---- 2) 预扣资金来源 ---- + if err := s.funding.PreConsume(effectiveQuota); err != nil { + // 回滚令牌额度 + if s.tokenConsumed > 0 && !s.relayInfo.IsPlayground { + _ = model.IncreaseTokenQuota(s.relayInfo.TokenId, s.relayInfo.TokenKey, s.tokenConsumed) + s.tokenConsumed = 0 + } + errMsg := err.Error() + if strings.Contains(errMsg, "no active subscription") || strings.Contains(errMsg, "subscription quota insufficient") { + return types.NewErrorWithStatusCode(fmt.Errorf("订阅额度不足或未配置订阅: %s", errMsg), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + if strings.Contains(errMsg, "用户额度不足") || strings.Contains(errMsg, "预扣费额度失败") { + return types.NewErrorWithStatusCode(err, types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + } + + s.preConsumedQuota = effectiveQuota + + // ---- 同步 RelayInfo 兼容字段 ---- + s.syncRelayInfo() + + return nil +} + +// shouldTrust 统一信任额度检查,适用于钱包和订阅。 +func (s *BillingSession) shouldTrust(c *gin.Context) bool { + trustQuota := common.GetTrustQuota() + if trustQuota <= 0 { + return false + } + + // 检查令牌是否充足 + tokenTrusted := s.relayInfo.TokenUnlimited + if !tokenTrusted { + tokenQuota := c.GetInt("token_quota") + tokenTrusted = tokenQuota > trustQuota + } + if !tokenTrusted { + return false + } + + switch s.funding.Source() { + case BillingSourceWallet: + return s.relayInfo.UserQuota > trustQuota + case BillingSourceSubscription: + // 订阅暂不支持信任旁路(订阅剩余额度需要额外查询,且预扣开销小) + // 后续可以在此处添加订阅信任逻辑 + return false + default: + return false + } +} + +// syncRelayInfo 将 BillingSession 的状态同步到 RelayInfo 的兼容字段上。 +func (s *BillingSession) syncRelayInfo() { + info := s.relayInfo + info.FinalPreConsumedQuota = s.preConsumedQuota + info.BillingSource = s.funding.Source() + + if sub, ok := s.funding.(*SubscriptionFunding); ok { + info.SubscriptionId = sub.subscriptionId + info.SubscriptionPreConsumed = sub.preConsumed + info.SubscriptionPostDelta = 0 + info.SubscriptionAmountTotal = sub.AmountTotal + info.SubscriptionAmountUsedAfterPreConsume = sub.AmountUsedAfter + info.SubscriptionPlanId = sub.PlanId + info.SubscriptionPlanTitle = sub.PlanTitle + } else { + info.SubscriptionId = 0 + info.SubscriptionPreConsumed = 0 + } +} + +// --------------------------------------------------------------------------- +// NewBillingSession 工厂 — 根据计费偏好创建会话并处理回退 +// --------------------------------------------------------------------------- + +// NewBillingSession 根据用户计费偏好创建 BillingSession,处理 subscription_first / wallet_first 的回退。 +func NewBillingSession(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) (*BillingSession, *types.NewAPIError) { + if relayInfo == nil { + return nil, types.NewError(fmt.Errorf("relayInfo is nil"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) + } + + pref := common.NormalizeBillingPreference(relayInfo.UserSetting.BillingPreference) + + // 钱包路径需要先检查用户额度 + tryWallet := func() (*BillingSession, *types.NewAPIError) { + userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + if err != nil { + return nil, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + } + if userQuota <= 0 { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), + types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, + types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + if userQuota-preConsumedQuota < 0 { + return nil, types.NewErrorWithStatusCode( + fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), + types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, + types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + } + relayInfo.UserQuota = userQuota + + session := &BillingSession{ + relayInfo: relayInfo, + funding: &WalletFunding{userId: relayInfo.UserId}, + } + if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil { + return nil, apiErr + } + return session, nil + } + + trySubscription := func() (*BillingSession, *types.NewAPIError) { + subConsume := int64(preConsumedQuota) + if subConsume <= 0 { + subConsume = 1 + } + session := &BillingSession{ + relayInfo: relayInfo, + funding: &SubscriptionFunding{ + requestId: relayInfo.RequestId, + userId: relayInfo.UserId, + modelName: relayInfo.OriginModelName, + amount: subConsume, + }, + } + if apiErr := session.preConsume(c, preConsumedQuota); apiErr != nil { + return nil, apiErr + } + return session, nil + } + + switch pref { + case "subscription_only": + return trySubscription() + case "wallet_only": + return tryWallet() + case "wallet_first": + session, err := tryWallet() + if err != nil { + if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota { + return trySubscription() + } + return nil, err + } + return session, nil + case "subscription_first": + fallthrough + default: + session, err := trySubscription() + if err != nil { + if err.GetErrorCode() == types.ErrorCodeInsufficientUserQuota { + return tryWallet() + } + return nil, err + } + return session, nil + } +} diff --git a/service/funding_source.go b/service/funding_source.go new file mode 100644 index 000000000..87672419c --- /dev/null +++ b/service/funding_source.go @@ -0,0 +1,137 @@ +package service + +import ( + "time" + + "github.com/QuantumNous/new-api/model" +) + +// --------------------------------------------------------------------------- +// FundingSource — 资金来源接口(钱包 or 订阅) +// --------------------------------------------------------------------------- + +// FundingSource 抽象了预扣费的资金来源。 +type FundingSource interface { + // Source 返回资金来源标识:"wallet" 或 "subscription" + Source() string + // PreConsume 从该资金来源预扣 amount 额度 + PreConsume(amount int) error + // Settle 根据差额调整资金来源(正数补扣,负数退还) + Settle(delta int) error + // Refund 退还所有预扣费 + Refund() error +} + +// --------------------------------------------------------------------------- +// WalletFunding — 钱包资金来源实现 +// --------------------------------------------------------------------------- + +type WalletFunding struct { + userId int + consumed int // 实际预扣的用户额度 +} + +func (w *WalletFunding) Source() string { return BillingSourceWallet } + +func (w *WalletFunding) PreConsume(amount int) error { + if amount <= 0 { + return nil + } + if err := model.DecreaseUserQuota(w.userId, amount); err != nil { + return err + } + w.consumed = amount + return nil +} + +func (w *WalletFunding) Settle(delta int) error { + if delta == 0 { + return nil + } + if delta > 0 { + return model.DecreaseUserQuota(w.userId, delta) + } + return model.IncreaseUserQuota(w.userId, -delta, false) +} + +func (w *WalletFunding) Refund() error { + if w.consumed <= 0 { + return nil + } + return model.IncreaseUserQuota(w.userId, w.consumed, false) +} + +// --------------------------------------------------------------------------- +// SubscriptionFunding — 订阅资金来源实现 +// --------------------------------------------------------------------------- + +type SubscriptionFunding struct { + requestId string + userId int + modelName string + amount int64 // 预扣的订阅额度(subConsume) + subscriptionId int + preConsumed int64 + // 以下字段在 PreConsume 成功后填充,供 RelayInfo 同步使用 + AmountTotal int64 + AmountUsedAfter int64 + PlanId int + PlanTitle string +} + +func (s *SubscriptionFunding) Source() string { return BillingSourceSubscription } + +func (s *SubscriptionFunding) PreConsume(_ int) error { + // amount 参数被忽略,使用内部 s.amount(已在构造时根据 preConsumedQuota 计算) + res, err := model.PreConsumeUserSubscription(s.requestId, s.userId, s.modelName, 0, s.amount) + if err != nil { + return err + } + s.subscriptionId = res.UserSubscriptionId + s.preConsumed = res.PreConsumed + s.AmountTotal = res.AmountTotal + s.AmountUsedAfter = res.AmountUsedAfter + // 获取订阅计划信息 + if planInfo, err := model.GetSubscriptionPlanInfoByUserSubscriptionId(res.UserSubscriptionId); err == nil && planInfo != nil { + s.PlanId = planInfo.PlanId + s.PlanTitle = planInfo.PlanTitle + } + return nil +} + +func (s *SubscriptionFunding) Settle(delta int) error { + if delta == 0 { + return nil + } + return model.PostConsumeUserSubscriptionDelta(s.subscriptionId, int64(delta)) +} + +func (s *SubscriptionFunding) Refund() error { + if s.preConsumed <= 0 { + return nil + } + return refundWithRetry(func() error { + return model.RefundSubscriptionPreConsume(s.requestId) + }) +} + +// refundWithRetry 尝试多次执行退款操作以提高成功率,只能用于基于事务的退款函数!!!!!! +// try to refund with retries, only for refund functions based on transactions!!! +func refundWithRetry(fn func() error) error { + if fn == nil { + return nil + } + const maxAttempts = 3 + var lastErr error + for i := 0; i < maxAttempts; i++ { + if err := fn(); err == nil { + return nil + } else { + lastErr = err + } + if i < maxAttempts-1 { + time.Sleep(time.Duration(200*(i+1)) * time.Millisecond) + } + } + return lastErr +} diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go deleted file mode 100644 index 3b049cb4f..000000000 --- a/service/pre_consume_quota.go +++ /dev/null @@ -1,124 +0,0 @@ -package service - -import ( - "fmt" - "net/http" - "time" - - "github.com/QuantumNous/new-api/common" - "github.com/QuantumNous/new-api/logger" - "github.com/QuantumNous/new-api/model" - relaycommon "github.com/QuantumNous/new-api/relay/common" - "github.com/QuantumNous/new-api/types" - - "github.com/bytedance/gopkg/util/gopool" - "github.com/gin-gonic/gin" -) - -func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { - // Always refund subscription pre-consumed (can be non-zero even when FinalPreConsumedQuota is 0) - needRefundSub := relayInfo.BillingSource == BillingSourceSubscription && relayInfo.SubscriptionId != 0 && relayInfo.SubscriptionPreConsumed > 0 - needRefundToken := relayInfo.FinalPreConsumedQuota != 0 - if !needRefundSub && !needRefundToken { - return - } - logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费(token_quota=%s, subscription=%d)", - relayInfo.UserId, - logger.FormatQuota(relayInfo.FinalPreConsumedQuota), - relayInfo.SubscriptionPreConsumed, - )) - gopool.Go(func() { - relayInfoCopy := *relayInfo - if relayInfoCopy.BillingSource == BillingSourceSubscription { - if needRefundSub { - if err := refundWithRetry(func() error { - return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId) - }); err != nil { - common.SysLog("error refund subscription pre-consume: " + err.Error()) - } - } - // refund token quota only - if needRefundToken && !relayInfoCopy.IsPlayground { - _ = model.IncreaseTokenQuota(relayInfoCopy.TokenId, relayInfoCopy.TokenKey, relayInfoCopy.FinalPreConsumedQuota) - } - return - } - - // wallet refund uses existing path (user quota + token quota) - if needRefundToken { - err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false) - if err != nil { - common.SysLog("error return pre-consumed quota: " + err.Error()) - } - } - }) -} - -func refundWithRetry(fn func() error) error { - if fn == nil { - return nil - } - const maxAttempts = 3 - var lastErr error - for i := 0; i < maxAttempts; i++ { - if err := fn(); err == nil { - return nil - } else { - lastErr = err - } - if i < maxAttempts-1 { - time.Sleep(time.Duration(200*(i+1)) * time.Millisecond) - } - } - return lastErr -} - -// PreConsumeQuota checks if the user has enough quota to pre-consume. -// It returns the pre-consumed quota if successful, or an error if not. -func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError { - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) - if err != nil { - return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) - } - if userQuota <= 0 { - return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - if userQuota-preConsumedQuota < 0 { - return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - - trustQuota := common.GetTrustQuota() - - relayInfo.UserQuota = userQuota - if userQuota > trustQuota { - // 用户额度充足,判断令牌额度是否充足 - if !relayInfo.TokenUnlimited { - // 非无限令牌,判断令牌额度是否充足 - tokenQuota := c.GetInt("token_quota") - if tokenQuota > trustQuota { - // 令牌额度充足,信任令牌 - preConsumedQuota = 0 - logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota)) - } - } else { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId)) - } - } - - if preConsumedQuota > 0 { - err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) - if err != nil { - return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) - } - err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) - if err != nil { - return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) - } - logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) - } - relayInfo.FinalPreConsumedQuota = preConsumedQuota - return nil -} diff --git a/service/quota.go b/service/quota.go index 951eecec5..5ffc2a723 100644 --- a/service/quota.go +++ b/service/quota.go @@ -307,27 +307,8 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - relayInfo.FinalPreConsumedQuota - - if quotaDelta > 0 { - logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", - logger.FormatQuota(quotaDelta), - logger.FormatQuota(quota), - logger.FormatQuota(relayInfo.FinalPreConsumedQuota), - )) - } else if quotaDelta < 0 { - logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", - logger.FormatQuota(-quotaDelta), - logger.FormatQuota(quota), - logger.FormatQuota(relayInfo.FinalPreConsumedQuota), - )) - } - - if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) - if err != nil { - logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } + if err := SettleBilling(ctx, relayInfo, quota); err != nil { + logger.LogError(ctx, "error settling billing: "+err.Error()) } other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, @@ -432,27 +413,8 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, u model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) } - quotaDelta := quota - relayInfo.FinalPreConsumedQuota - - if quotaDelta > 0 { - logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)", - logger.FormatQuota(quotaDelta), - logger.FormatQuota(quota), - logger.FormatQuota(relayInfo.FinalPreConsumedQuota), - )) - } else if quotaDelta < 0 { - logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)", - logger.FormatQuota(-quotaDelta), - logger.FormatQuota(quota), - logger.FormatQuota(relayInfo.FinalPreConsumedQuota), - )) - } - - if quotaDelta != 0 { - err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true) - if err != nil { - logger.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } + if err := SettleBilling(ctx, relayInfo, quota); err != nil { + logger.LogError(ctx, "error settling billing: "+err.Error()) } logModel := relayInfo.OriginModelName