mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 05:41:37 +00:00
✨ feat(subscription): harden subscription billing with resets, idempotency, and production-grade stability
Add plan-level quota reset periods and display/reset cadence in admin/UI Enforce natural reset alignment with background reset task and cleanup job Make subscription pre-consume/refund idempotent with request-scoped records and retries Use database time for consistent resets across multi-instance deployments Harden payment callbacks with locking and idempotent order completion Record subscription purchases in topup history and billing logs Optimize subscription queries and add critical composite indexes
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
@@ -298,15 +299,16 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
// Subscription order takes precedence (accept both onetime/subscription types)
|
||||
if model.GetSubscriptionOrderByTradeNo(referenceId) != nil {
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, jsonString(event)); err != nil {
|
||||
log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId)
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Try complete subscription order first
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, jsonString(event)); err == nil {
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId)
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证订单类型,目前只处理一次性付款(充值)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -166,17 +167,19 @@ func sessionCompleted(event stripe.Event) {
|
||||
return
|
||||
}
|
||||
|
||||
// Subscription order takes precedence
|
||||
if model.GetSubscriptionOrderByTradeNo(referenceId) != nil {
|
||||
payload := map[string]any{
|
||||
"customer": customerId,
|
||||
"amount_total": event.GetObjectValue("amount_total"),
|
||||
"currency": strings.ToUpper(event.GetObjectValue("currency")),
|
||||
"event_type": string(event.Type),
|
||||
}
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, jsonString(payload)); err != nil {
|
||||
log.Println("complete subscription order failed:", err.Error(), referenceId)
|
||||
}
|
||||
// Try complete subscription order first
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
payload := map[string]any{
|
||||
"customer": customerId,
|
||||
"amount_total": event.GetObjectValue("amount_total"),
|
||||
"currency": strings.ToUpper(event.GetObjectValue("currency")),
|
||||
"event_type": string(event.Type),
|
||||
}
|
||||
if err := model.CompleteSubscriptionOrder(referenceId, jsonString(payload)); err == nil {
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
log.Println("complete subscription order failed:", err.Error(), referenceId)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -205,10 +208,12 @@ func sessionExpired(event stripe.Event) {
|
||||
}
|
||||
|
||||
// Subscription order expiration
|
||||
if model.GetSubscriptionOrderByTradeNo(referenceId) != nil {
|
||||
if err := model.ExpireSubscriptionOrder(referenceId); err != nil {
|
||||
log.Println("过期订阅订单失败", referenceId, ", err:", err.Error())
|
||||
}
|
||||
LockOrder(referenceId)
|
||||
defer UnlockOrder(referenceId)
|
||||
if err := model.ExpireSubscriptionOrder(referenceId); err == nil {
|
||||
return
|
||||
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
||||
log.Println("过期订阅订单失败", referenceId, ", err:", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
3
main.go
3
main.go
@@ -106,6 +106,9 @@ func main() {
|
||||
// Codex credential auto-refresh check every 10 minutes, refresh when expires within 1 day
|
||||
service.StartCodexCredentialAutoRefreshTask()
|
||||
|
||||
// Subscription quota reset task (daily/weekly/monthly/custom)
|
||||
service.StartSubscriptionQuotaResetTask()
|
||||
|
||||
if common.IsMasterNode && constant.UpdateTask {
|
||||
gopool.Go(func() {
|
||||
controller.UpdateMidjourneyTaskBulk()
|
||||
|
||||
22
model/db_time.go
Normal file
22
model/db_time.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package model
|
||||
|
||||
import "github.com/QuantumNous/new-api/common"
|
||||
|
||||
// GetDBTimestamp returns a UNIX timestamp from database time.
|
||||
// Falls back to application time on error.
|
||||
func GetDBTimestamp() int64 {
|
||||
var ts int64
|
||||
var err error
|
||||
switch {
|
||||
case common.UsingPostgreSQL:
|
||||
err = DB.Raw("SELECT EXTRACT(EPOCH FROM NOW())").Scan(&ts).Error
|
||||
case common.UsingSQLite:
|
||||
err = DB.Raw("SELECT strftime('%s','now')").Scan(&ts).Error
|
||||
default:
|
||||
err = DB.Raw("SELECT UNIX_TIMESTAMP()").Scan(&ts).Error
|
||||
}
|
||||
if err != nil || ts <= 0 {
|
||||
return common.GetTimestamp()
|
||||
}
|
||||
return ts
|
||||
}
|
||||
@@ -273,6 +273,7 @@ func migrateDB() error {
|
||||
&SubscriptionOrder{},
|
||||
&UserSubscription{},
|
||||
&UserSubscriptionItem{},
|
||||
&SubscriptionPreConsumeRecord{},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -312,6 +313,7 @@ func migrateDBFast() error {
|
||||
{&SubscriptionOrder{}, "SubscriptionOrder"},
|
||||
{&UserSubscription{}, "UserSubscription"},
|
||||
{&UserSubscriptionItem{}, "UserSubscriptionItem"},
|
||||
{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
|
||||
}
|
||||
// 动态计算migration数量,确保errChan缓冲区足够大
|
||||
errChan := make(chan error, len(migrations))
|
||||
|
||||
@@ -28,6 +28,11 @@ const (
|
||||
SubscriptionResetCustom = "custom"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSubscriptionOrderNotFound = errors.New("subscription order not found")
|
||||
ErrSubscriptionOrderStatusInvalid = errors.New("subscription order status invalid")
|
||||
)
|
||||
|
||||
// Subscription plan
|
||||
type SubscriptionPlan struct {
|
||||
Id int `json:"id"`
|
||||
@@ -122,12 +127,12 @@ func GetSubscriptionOrderByTradeNo(tradeNo string) *SubscriptionOrder {
|
||||
// User subscription instance
|
||||
type UserSubscription struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
UserId int `json:"user_id" gorm:"index;index:idx_user_sub_active,priority:1"`
|
||||
PlanId int `json:"plan_id" gorm:"index"`
|
||||
|
||||
StartTime int64 `json:"start_time" gorm:"bigint"`
|
||||
EndTime int64 `json:"end_time" gorm:"bigint;index"`
|
||||
Status string `json:"status" gorm:"type:varchar(32);index"` // active/expired/cancelled
|
||||
EndTime int64 `json:"end_time" gorm:"bigint;index;index:idx_user_sub_active,priority:3"`
|
||||
Status string `json:"status" gorm:"type:varchar(32);index;index:idx_user_sub_active,priority:2"` // active/expired/cancelled
|
||||
|
||||
Source string `json:"source" gorm:"type:varchar(32);default:'order'"` // order/admin
|
||||
|
||||
@@ -149,9 +154,9 @@ func (s *UserSubscription) BeforeUpdate(tx *gorm.DB) error {
|
||||
|
||||
type UserSubscriptionItem struct {
|
||||
Id int `json:"id"`
|
||||
UserSubscriptionId int `json:"user_subscription_id" gorm:"index"`
|
||||
ModelName string `json:"model_name" gorm:"type:varchar(128);index"`
|
||||
QuotaType int `json:"quota_type" gorm:"type:int;index"`
|
||||
UserSubscriptionId int `json:"user_subscription_id" gorm:"index;index:idx_sub_item_model_quota,priority:3"`
|
||||
ModelName string `json:"model_name" gorm:"type:varchar(128);index;index:idx_sub_item_model_quota,priority:1"`
|
||||
QuotaType int `json:"quota_type" gorm:"type:int;index;index:idx_sub_item_model_quota,priority:2"`
|
||||
AmountTotal int64 `json:"amount_total" gorm:"type:bigint;not null;default:0"`
|
||||
AmountUsed int64 `json:"amount_used" gorm:"type:bigint;not null;default:0"`
|
||||
LastResetTime int64 `json:"last_reset_time" gorm:"type:bigint;default:0"`
|
||||
@@ -209,11 +214,22 @@ func calcNextResetTime(base time.Time, plan *SubscriptionPlan, endUnix int64) in
|
||||
var next time.Time
|
||||
switch period {
|
||||
case SubscriptionResetDaily:
|
||||
next = base.Add(24 * time.Hour)
|
||||
next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()).
|
||||
AddDate(0, 0, 1)
|
||||
case SubscriptionResetWeekly:
|
||||
next = base.AddDate(0, 0, 7)
|
||||
// Align to next Monday 00:00
|
||||
weekday := int(base.Weekday()) // Sunday=0
|
||||
// Convert to Monday=1..Sunday=7
|
||||
if weekday == 0 {
|
||||
weekday = 7
|
||||
}
|
||||
daysUntil := 8 - weekday
|
||||
next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()).
|
||||
AddDate(0, 0, daysUntil)
|
||||
case SubscriptionResetMonthly:
|
||||
next = base.AddDate(0, 1, 0)
|
||||
// Align to first day of next month 00:00
|
||||
next = time.Date(base.Year(), base.Month(), 1, 0, 0, 0, 0, base.Location()).
|
||||
AddDate(0, 1, 0)
|
||||
case SubscriptionResetCustom:
|
||||
if plan.QuotaResetCustomSeconds <= 0 {
|
||||
return 0
|
||||
@@ -260,7 +276,8 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio
|
||||
if userId <= 0 {
|
||||
return nil, errors.New("invalid user id")
|
||||
}
|
||||
now := time.Now()
|
||||
nowUnix := GetDBTimestamp()
|
||||
now := time.Unix(nowUnix, 0)
|
||||
endUnix, err := calcPlanEndTime(now, plan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -325,13 +342,13 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string) error {
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
var order SubscriptionOrder
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return errors.New("subscription order not found")
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if order.Status == common.TopUpStatusSuccess {
|
||||
return nil
|
||||
}
|
||||
if order.Status != common.TopUpStatusPending {
|
||||
return errors.New("subscription order status invalid")
|
||||
return ErrSubscriptionOrderStatusInvalid
|
||||
}
|
||||
plan, err := GetSubscriptionPlanById(order.PlanId)
|
||||
if err != nil {
|
||||
@@ -416,7 +433,7 @@ func ExpireSubscriptionOrder(tradeNo string) error {
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
var order SubscriptionOrder
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
|
||||
return errors.New("subscription order not found")
|
||||
return ErrSubscriptionOrderNotFound
|
||||
}
|
||||
if order.Status != common.TopUpStatusPending {
|
||||
return nil
|
||||
@@ -455,16 +472,7 @@ func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := make([]SubscriptionSummary, 0, len(subs))
|
||||
for _, sub := range subs {
|
||||
var items []UserSubscriptionItem
|
||||
if err := DB.Where("user_subscription_id = ?", sub.Id).Find(&items).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
subCopy := sub
|
||||
result = append(result, SubscriptionSummary{Subscription: &subCopy, Items: items})
|
||||
}
|
||||
return result, nil
|
||||
return buildSubscriptionSummaries(subs)
|
||||
}
|
||||
|
||||
// GetAllUserSubscriptions returns all subscriptions (active and expired) for a user.
|
||||
@@ -479,14 +487,32 @@ func GetAllUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buildSubscriptionSummaries(subs)
|
||||
}
|
||||
|
||||
func buildSubscriptionSummaries(subs []UserSubscription) ([]SubscriptionSummary, error) {
|
||||
if len(subs) == 0 {
|
||||
return []SubscriptionSummary{}, nil
|
||||
}
|
||||
subIds := make([]int, 0, len(subs))
|
||||
for _, sub := range subs {
|
||||
subIds = append(subIds, sub.Id)
|
||||
}
|
||||
var items []UserSubscriptionItem
|
||||
if err := DB.Where("user_subscription_id IN ?", subIds).Find(&items).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
itemsMap := make(map[int][]UserSubscriptionItem, len(subIds))
|
||||
for _, it := range items {
|
||||
itemsMap[it.UserSubscriptionId] = append(itemsMap[it.UserSubscriptionId], it)
|
||||
}
|
||||
result := make([]SubscriptionSummary, 0, len(subs))
|
||||
for _, sub := range subs {
|
||||
var items []UserSubscriptionItem
|
||||
if err := DB.Where("user_subscription_id = ?", sub.Id).Find(&items).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
subCopy := sub
|
||||
result = append(result, SubscriptionSummary{Subscription: &subCopy, Items: items})
|
||||
result = append(result, SubscriptionSummary{
|
||||
Subscription: &subCopy,
|
||||
Items: itemsMap[sub.Id],
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -539,6 +565,30 @@ type SubscriptionPreConsumeResult struct {
|
||||
AmountUsedAfter int64
|
||||
}
|
||||
|
||||
// SubscriptionPreConsumeRecord stores idempotent pre-consume operations per request.
|
||||
type SubscriptionPreConsumeRecord struct {
|
||||
Id int `json:"id"`
|
||||
RequestId string `json:"request_id" gorm:"type:varchar(64);uniqueIndex"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
UserSubscriptionItemId int `json:"user_subscription_item_id" gorm:"index"`
|
||||
PreConsumed int64 `json:"pre_consumed" gorm:"type:bigint;not null;default:0"`
|
||||
Status string `json:"status" gorm:"type:varchar(32);index"` // consumed/refunded
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint"`
|
||||
UpdatedAt int64 `json:"updated_at" gorm:"bigint;index"`
|
||||
}
|
||||
|
||||
func (r *SubscriptionPreConsumeRecord) BeforeCreate(tx *gorm.DB) error {
|
||||
now := common.GetTimestamp()
|
||||
r.CreatedAt = now
|
||||
r.UpdatedAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SubscriptionPreConsumeRecord) BeforeUpdate(tx *gorm.DB) error {
|
||||
r.UpdatedAt = common.GetTimestamp()
|
||||
return nil
|
||||
}
|
||||
|
||||
func maybeResetSubscriptionItemTx(tx *gorm.DB, item *UserSubscriptionItem, now int64) error {
|
||||
if tx == nil || item == nil {
|
||||
return errors.New("invalid reset args")
|
||||
@@ -587,20 +637,43 @@ func maybeResetSubscriptionItemTx(tx *gorm.DB, item *UserSubscriptionItem, now i
|
||||
|
||||
// PreConsumeUserSubscription finds a valid active subscription item and increments amount_used.
|
||||
// quotaType=0 => consume quota units; quotaType=1 => consume request count (usually 1).
|
||||
func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amount int64) (*SubscriptionPreConsumeResult, error) {
|
||||
func PreConsumeUserSubscription(requestId string, userId int, modelName string, quotaType int, amount int64) (*SubscriptionPreConsumeResult, error) {
|
||||
if userId <= 0 {
|
||||
return nil, errors.New("invalid userId")
|
||||
}
|
||||
if strings.TrimSpace(requestId) == "" {
|
||||
return nil, errors.New("requestId is empty")
|
||||
}
|
||||
if modelName == "" {
|
||||
return nil, errors.New("modelName is empty")
|
||||
}
|
||||
if amount <= 0 {
|
||||
return nil, errors.New("amount must be > 0")
|
||||
}
|
||||
now := common.GetTimestamp()
|
||||
now := GetDBTimestamp()
|
||||
|
||||
returnValue := &SubscriptionPreConsumeResult{}
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
var existing SubscriptionPreConsumeRecord
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
Where("request_id = ?", requestId).First(&existing).Error; err == nil {
|
||||
if existing.Status == "refunded" {
|
||||
return errors.New("subscription pre-consume already refunded")
|
||||
}
|
||||
var item UserSubscriptionItem
|
||||
if err := tx.Where("id = ?", existing.UserSubscriptionItemId).First(&item).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
returnValue.UserSubscriptionId = item.UserSubscriptionId
|
||||
returnValue.ItemId = item.Id
|
||||
returnValue.QuotaType = item.QuotaType
|
||||
returnValue.PreConsumed = existing.PreConsumed
|
||||
returnValue.AmountTotal = item.AmountTotal
|
||||
returnValue.AmountUsedBefore = item.AmountUsed
|
||||
returnValue.AmountUsedAfter = item.AmountUsed
|
||||
return nil
|
||||
}
|
||||
|
||||
var item UserSubscriptionItem
|
||||
// lock item row; join to ensure subscription still active
|
||||
q := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
@@ -609,17 +682,14 @@ func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amo
|
||||
Joins("JOIN user_subscriptions ON user_subscriptions.id = user_subscription_items.user_subscription_id").
|
||||
Where("user_subscriptions.user_id = ? AND user_subscriptions.status = ? AND user_subscriptions.end_time > ?", userId, "active", now).
|
||||
Where("user_subscription_items.model_name = ? AND user_subscription_items.quota_type = ?", modelName, quotaType).
|
||||
Order("user_subscriptions.end_time desc, user_subscriptions.id desc, user_subscription_items.id desc")
|
||||
Order("user_subscriptions.end_time desc, user_subscriptions.id desc, user_subscription_items.id desc").
|
||||
Limit(1)
|
||||
if err := q.First(&item).Error; err != nil {
|
||||
return errors.New("no active subscription item for this model")
|
||||
}
|
||||
if err := maybeResetSubscriptionItemTx(tx, &item, now); err != nil {
|
||||
return err
|
||||
}
|
||||
// reload item after potential reset
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").Where("id = ?", item.Id).First(&item).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
usedBefore := item.AmountUsed
|
||||
remain := item.AmountTotal - usedBefore
|
||||
if remain < amount {
|
||||
@@ -629,6 +699,16 @@ func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amo
|
||||
if err := tx.Save(&item).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
record := &SubscriptionPreConsumeRecord{
|
||||
RequestId: requestId,
|
||||
UserId: userId,
|
||||
UserSubscriptionItemId: item.Id,
|
||||
PreConsumed: amount,
|
||||
Status: "consumed",
|
||||
}
|
||||
if err := tx.Create(record).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
returnValue.UserSubscriptionId = item.UserSubscriptionId
|
||||
returnValue.ItemId = item.Id
|
||||
returnValue.QuotaType = item.QuotaType
|
||||
@@ -644,6 +724,80 @@ func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amo
|
||||
return returnValue, nil
|
||||
}
|
||||
|
||||
// RefundSubscriptionPreConsume is idempotent and refunds pre-consumed subscription quota by requestId.
|
||||
func RefundSubscriptionPreConsume(requestId string) error {
|
||||
if strings.TrimSpace(requestId) == "" {
|
||||
return errors.New("requestId is empty")
|
||||
}
|
||||
return DB.Transaction(func(tx *gorm.DB) error {
|
||||
var record SubscriptionPreConsumeRecord
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
Where("request_id = ?", requestId).First(&record).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if record.Status == "refunded" {
|
||||
return nil
|
||||
}
|
||||
if record.PreConsumed <= 0 {
|
||||
record.Status = "refunded"
|
||||
return tx.Save(&record).Error
|
||||
}
|
||||
if err := PostConsumeUserSubscriptionDelta(record.UserSubscriptionItemId, -record.PreConsumed); err != nil {
|
||||
return err
|
||||
}
|
||||
record.Status = "refunded"
|
||||
return tx.Save(&record).Error
|
||||
})
|
||||
}
|
||||
|
||||
// ResetDueSubscriptionItems resets items whose next_reset_time has passed.
|
||||
func ResetDueSubscriptionItems(limit int) (int, error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
now := GetDBTimestamp()
|
||||
var items []UserSubscriptionItem
|
||||
if err := DB.Where("next_reset_time > 0 AND next_reset_time <= ?", now).
|
||||
Order("next_reset_time asc").
|
||||
Limit(limit).
|
||||
Find(&items).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
resetCount := 0
|
||||
for _, it := range items {
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
var item UserSubscriptionItem
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
Where("id = ? AND next_reset_time > 0 AND next_reset_time <= ?", it.Id, now).
|
||||
First(&item).Error; err != nil {
|
||||
return nil
|
||||
}
|
||||
if err := maybeResetSubscriptionItemTx(tx, &item, now); err != nil {
|
||||
return err
|
||||
}
|
||||
resetCount++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return resetCount, err
|
||||
}
|
||||
}
|
||||
return resetCount, nil
|
||||
}
|
||||
|
||||
// CleanupSubscriptionPreConsumeRecords removes old idempotency records to keep table small.
|
||||
func CleanupSubscriptionPreConsumeRecords(olderThanSeconds int64) (int64, error) {
|
||||
if olderThanSeconds <= 0 {
|
||||
olderThanSeconds = 7 * 24 * 3600
|
||||
}
|
||||
cutoff := GetDBTimestamp() - olderThanSeconds
|
||||
res := DB.Where("updated_at < ?", cutoff).Delete(&SubscriptionPreConsumeRecord{})
|
||||
return res.RowsAffected, res.Error
|
||||
}
|
||||
|
||||
type SubscriptionPlanInfo struct {
|
||||
PlanId int
|
||||
PlanTitle string
|
||||
|
||||
@@ -129,6 +129,8 @@ type RelayInfo struct {
|
||||
// SubscriptionPlanId / SubscriptionPlanTitle are used for logging/UI display.
|
||||
SubscriptionPlanId int
|
||||
SubscriptionPlanTitle string
|
||||
// RequestId is used for idempotent pre-consume/refund
|
||||
RequestId string
|
||||
// SubscriptionAmountTotal / SubscriptionAmountUsedAfterPreConsume are used to compute remaining in logs.
|
||||
SubscriptionAmountTotal int64
|
||||
SubscriptionAmountUsedAfterPreConsume int64
|
||||
@@ -418,9 +420,14 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
|
||||
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
|
||||
reqId := common.GetContextKeyString(c, common.RequestIdKey)
|
||||
if reqId == "" {
|
||||
reqId = common.GetTimeString() + common.GetRandomString(8)
|
||||
}
|
||||
info := &RelayInfo{
|
||||
Request: request,
|
||||
|
||||
RequestId: reqId,
|
||||
UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId),
|
||||
UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup),
|
||||
UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup),
|
||||
|
||||
@@ -56,7 +56,7 @@ func PreConsumeBilling(c *gin.Context, preConsumedQuota int, relayInfo *relaycom
|
||||
}
|
||||
}
|
||||
|
||||
res, err := model.PreConsumeUserSubscription(relayInfo.UserId, relayInfo.OriginModelName, quotaType, subConsume)
|
||||
res, err := model.PreConsumeUserSubscription(relayInfo.RequestId, relayInfo.UserId, relayInfo.OriginModelName, quotaType, subConsume)
|
||||
if err != nil {
|
||||
// revert token pre-consume when subscription fails
|
||||
if preConsumedQuota > 0 && !relayInfo.IsPlayground {
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
@@ -30,7 +31,9 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
relayInfoCopy := *relayInfo
|
||||
if relayInfoCopy.BillingSource == BillingSourceSubscription {
|
||||
if needRefundSub {
|
||||
_ = model.PostConsumeUserSubscriptionDelta(relayInfoCopy.SubscriptionItemId, -relayInfoCopy.SubscriptionPreConsumed)
|
||||
refundWithRetry(func() error {
|
||||
return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId)
|
||||
})
|
||||
}
|
||||
// refund token quota only
|
||||
if needRefundToken && !relayInfoCopy.IsPlayground {
|
||||
@@ -49,6 +52,21 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
})
|
||||
}
|
||||
|
||||
func refundWithRetry(fn func() error) {
|
||||
if fn == nil {
|
||||
return
|
||||
}
|
||||
const maxAttempts = 3
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
if err := fn(); err == nil {
|
||||
return
|
||||
}
|
||||
if i < maxAttempts-1 {
|
||||
time.Sleep(time.Duration(200*(i+1)) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
||||
// It returns the pre-consumed quota if successful, or an error if not.
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
|
||||
78
service/subscription_reset_task.go
Normal file
78
service/subscription_reset_task.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
)
|
||||
|
||||
const (
|
||||
subscriptionResetTickInterval = 1 * time.Minute
|
||||
subscriptionResetBatchSize = 300
|
||||
subscriptionCleanupInterval = 30 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
subscriptionResetOnce sync.Once
|
||||
subscriptionResetRunning atomic.Bool
|
||||
subscriptionCleanupLast atomic.Int64
|
||||
)
|
||||
|
||||
func StartSubscriptionQuotaResetTask() {
|
||||
subscriptionResetOnce.Do(func() {
|
||||
if !common.IsMasterNode {
|
||||
return
|
||||
}
|
||||
gopool.Go(func() {
|
||||
logger.LogInfo(context.Background(), fmt.Sprintf("subscription quota reset task started: tick=%s", subscriptionResetTickInterval))
|
||||
ticker := time.NewTicker(subscriptionResetTickInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
runSubscriptionQuotaResetOnce()
|
||||
for range ticker.C {
|
||||
runSubscriptionQuotaResetOnce()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func runSubscriptionQuotaResetOnce() {
|
||||
if !subscriptionResetRunning.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
defer subscriptionResetRunning.Store(false)
|
||||
|
||||
ctx := context.Background()
|
||||
totalReset := 0
|
||||
for {
|
||||
n, err := model.ResetDueSubscriptionItems(subscriptionResetBatchSize)
|
||||
if err != nil {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("subscription quota reset task failed: %v", err))
|
||||
return
|
||||
}
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
totalReset += n
|
||||
if n < subscriptionResetBatchSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
lastCleanup := time.Unix(subscriptionCleanupLast.Load(), 0)
|
||||
if time.Since(lastCleanup) >= subscriptionCleanupInterval {
|
||||
if _, err := model.CleanupSubscriptionPreConsumeRecords(7 * 24 * 3600); err == nil {
|
||||
subscriptionCleanupLast.Store(time.Now().Unix())
|
||||
}
|
||||
}
|
||||
if totalReset > 0 && common.DebugEnabled {
|
||||
logger.LogDebug(ctx, "subscription quota reset: reset_count=%d", totalReset)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user