diff --git a/controller/topup_creem.go b/controller/topup_creem.go index fab360338..cd83af045 100644 --- a/controller/topup_creem.go +++ b/controller/topup_creem.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" @@ -298,15 +299,16 @@ func handleCheckoutCompleted(c *gin.Context, event *CreemWebhookEvent) { return } - // Subscription order takes precedence (accept both onetime/subscription types) - if model.GetSubscriptionOrderByTradeNo(referenceId) != nil { - if err := model.CompleteSubscriptionOrder(referenceId, jsonString(event)); err != nil { - log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId) - c.AbortWithStatus(http.StatusInternalServerError) - return - } + // Try complete subscription order first + LockOrder(referenceId) + defer UnlockOrder(referenceId) + if err := model.CompleteSubscriptionOrder(referenceId, jsonString(event)); err == nil { c.Status(http.StatusOK) return + } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { + log.Printf("Creem订阅订单处理失败: %s, 订单号: %s", err.Error(), referenceId) + c.AbortWithStatus(http.StatusInternalServerError) + return } // 验证订单类型,目前只处理一次性付款(充值) diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index 4a4c4102c..995a50af3 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -1,6 +1,7 @@ package controller import ( + "errors" "fmt" "io" "log" @@ -166,17 +167,19 @@ func sessionCompleted(event stripe.Event) { return } - // Subscription order takes precedence - if model.GetSubscriptionOrderByTradeNo(referenceId) != nil { - payload := map[string]any{ - "customer": customerId, - "amount_total": event.GetObjectValue("amount_total"), - "currency": strings.ToUpper(event.GetObjectValue("currency")), - "event_type": string(event.Type), - } - if err := model.CompleteSubscriptionOrder(referenceId, jsonString(payload)); err != nil { - log.Println("complete subscription order failed:", err.Error(), referenceId) - } + // Try complete subscription order first + LockOrder(referenceId) + defer UnlockOrder(referenceId) + payload := map[string]any{ + "customer": customerId, + "amount_total": event.GetObjectValue("amount_total"), + "currency": strings.ToUpper(event.GetObjectValue("currency")), + "event_type": string(event.Type), + } + if err := model.CompleteSubscriptionOrder(referenceId, jsonString(payload)); err == nil { + return + } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { + log.Println("complete subscription order failed:", err.Error(), referenceId) return } @@ -205,10 +208,12 @@ func sessionExpired(event stripe.Event) { } // Subscription order expiration - if model.GetSubscriptionOrderByTradeNo(referenceId) != nil { - if err := model.ExpireSubscriptionOrder(referenceId); err != nil { - log.Println("过期订阅订单失败", referenceId, ", err:", err.Error()) - } + LockOrder(referenceId) + defer UnlockOrder(referenceId) + if err := model.ExpireSubscriptionOrder(referenceId); err == nil { + return + } else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) { + log.Println("过期订阅订单失败", referenceId, ", err:", err.Error()) return } diff --git a/main.go b/main.go index 23953b877..0964530e1 100644 --- a/main.go +++ b/main.go @@ -106,6 +106,9 @@ func main() { // Codex credential auto-refresh check every 10 minutes, refresh when expires within 1 day service.StartCodexCredentialAutoRefreshTask() + // Subscription quota reset task (daily/weekly/monthly/custom) + service.StartSubscriptionQuotaResetTask() + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() diff --git a/model/db_time.go b/model/db_time.go new file mode 100644 index 000000000..a01a33187 --- /dev/null +++ b/model/db_time.go @@ -0,0 +1,22 @@ +package model + +import "github.com/QuantumNous/new-api/common" + +// GetDBTimestamp returns a UNIX timestamp from database time. +// Falls back to application time on error. +func GetDBTimestamp() int64 { + var ts int64 + var err error + switch { + case common.UsingPostgreSQL: + err = DB.Raw("SELECT EXTRACT(EPOCH FROM NOW())").Scan(&ts).Error + case common.UsingSQLite: + err = DB.Raw("SELECT strftime('%s','now')").Scan(&ts).Error + default: + err = DB.Raw("SELECT UNIX_TIMESTAMP()").Scan(&ts).Error + } + if err != nil || ts <= 0 { + return common.GetTimestamp() + } + return ts +} diff --git a/model/main.go b/model/main.go index dcd6c8841..ab5c7d714 100644 --- a/model/main.go +++ b/model/main.go @@ -273,6 +273,7 @@ func migrateDB() error { &SubscriptionOrder{}, &UserSubscription{}, &UserSubscriptionItem{}, + &SubscriptionPreConsumeRecord{}, ) if err != nil { return err @@ -312,6 +313,7 @@ func migrateDBFast() error { {&SubscriptionOrder{}, "SubscriptionOrder"}, {&UserSubscription{}, "UserSubscription"}, {&UserSubscriptionItem{}, "UserSubscriptionItem"}, + {&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"}, } // 动态计算migration数量,确保errChan缓冲区足够大 errChan := make(chan error, len(migrations)) diff --git a/model/subscription.go b/model/subscription.go index eda7475cb..590e82c3c 100644 --- a/model/subscription.go +++ b/model/subscription.go @@ -28,6 +28,11 @@ const ( SubscriptionResetCustom = "custom" ) +var ( + ErrSubscriptionOrderNotFound = errors.New("subscription order not found") + ErrSubscriptionOrderStatusInvalid = errors.New("subscription order status invalid") +) + // Subscription plan type SubscriptionPlan struct { Id int `json:"id"` @@ -122,12 +127,12 @@ func GetSubscriptionOrderByTradeNo(tradeNo string) *SubscriptionOrder { // User subscription instance type UserSubscription struct { Id int `json:"id"` - UserId int `json:"user_id" gorm:"index"` + UserId int `json:"user_id" gorm:"index;index:idx_user_sub_active,priority:1"` PlanId int `json:"plan_id" gorm:"index"` StartTime int64 `json:"start_time" gorm:"bigint"` - EndTime int64 `json:"end_time" gorm:"bigint;index"` - Status string `json:"status" gorm:"type:varchar(32);index"` // active/expired/cancelled + EndTime int64 `json:"end_time" gorm:"bigint;index;index:idx_user_sub_active,priority:3"` + Status string `json:"status" gorm:"type:varchar(32);index;index:idx_user_sub_active,priority:2"` // active/expired/cancelled Source string `json:"source" gorm:"type:varchar(32);default:'order'"` // order/admin @@ -149,9 +154,9 @@ func (s *UserSubscription) BeforeUpdate(tx *gorm.DB) error { type UserSubscriptionItem struct { Id int `json:"id"` - UserSubscriptionId int `json:"user_subscription_id" gorm:"index"` - ModelName string `json:"model_name" gorm:"type:varchar(128);index"` - QuotaType int `json:"quota_type" gorm:"type:int;index"` + UserSubscriptionId int `json:"user_subscription_id" gorm:"index;index:idx_sub_item_model_quota,priority:3"` + ModelName string `json:"model_name" gorm:"type:varchar(128);index;index:idx_sub_item_model_quota,priority:1"` + QuotaType int `json:"quota_type" gorm:"type:int;index;index:idx_sub_item_model_quota,priority:2"` AmountTotal int64 `json:"amount_total" gorm:"type:bigint;not null;default:0"` AmountUsed int64 `json:"amount_used" gorm:"type:bigint;not null;default:0"` LastResetTime int64 `json:"last_reset_time" gorm:"type:bigint;default:0"` @@ -209,11 +214,22 @@ func calcNextResetTime(base time.Time, plan *SubscriptionPlan, endUnix int64) in var next time.Time switch period { case SubscriptionResetDaily: - next = base.Add(24 * time.Hour) + next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()). + AddDate(0, 0, 1) case SubscriptionResetWeekly: - next = base.AddDate(0, 0, 7) + // Align to next Monday 00:00 + weekday := int(base.Weekday()) // Sunday=0 + // Convert to Monday=1..Sunday=7 + if weekday == 0 { + weekday = 7 + } + daysUntil := 8 - weekday + next = time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, base.Location()). + AddDate(0, 0, daysUntil) case SubscriptionResetMonthly: - next = base.AddDate(0, 1, 0) + // Align to first day of next month 00:00 + next = time.Date(base.Year(), base.Month(), 1, 0, 0, 0, 0, base.Location()). + AddDate(0, 1, 0) case SubscriptionResetCustom: if plan.QuotaResetCustomSeconds <= 0 { return 0 @@ -260,7 +276,8 @@ func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *Subscriptio if userId <= 0 { return nil, errors.New("invalid user id") } - now := time.Now() + nowUnix := GetDBTimestamp() + now := time.Unix(nowUnix, 0) endUnix, err := calcPlanEndTime(now, plan) if err != nil { return nil, err @@ -325,13 +342,13 @@ func CompleteSubscriptionOrder(tradeNo string, providerPayload string) error { err := DB.Transaction(func(tx *gorm.DB) error { var order SubscriptionOrder if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { - return errors.New("subscription order not found") + return ErrSubscriptionOrderNotFound } if order.Status == common.TopUpStatusSuccess { return nil } if order.Status != common.TopUpStatusPending { - return errors.New("subscription order status invalid") + return ErrSubscriptionOrderStatusInvalid } plan, err := GetSubscriptionPlanById(order.PlanId) if err != nil { @@ -416,7 +433,7 @@ func ExpireSubscriptionOrder(tradeNo string) error { return DB.Transaction(func(tx *gorm.DB) error { var order SubscriptionOrder if err := tx.Set("gorm:query_option", "FOR UPDATE").Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil { - return errors.New("subscription order not found") + return ErrSubscriptionOrderNotFound } if order.Status != common.TopUpStatusPending { return nil @@ -455,16 +472,7 @@ func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) { if err != nil { return nil, err } - result := make([]SubscriptionSummary, 0, len(subs)) - for _, sub := range subs { - var items []UserSubscriptionItem - if err := DB.Where("user_subscription_id = ?", sub.Id).Find(&items).Error; err != nil { - continue - } - subCopy := sub - result = append(result, SubscriptionSummary{Subscription: &subCopy, Items: items}) - } - return result, nil + return buildSubscriptionSummaries(subs) } // GetAllUserSubscriptions returns all subscriptions (active and expired) for a user. @@ -479,14 +487,32 @@ func GetAllUserSubscriptions(userId int) ([]SubscriptionSummary, error) { if err != nil { return nil, err } + return buildSubscriptionSummaries(subs) +} + +func buildSubscriptionSummaries(subs []UserSubscription) ([]SubscriptionSummary, error) { + if len(subs) == 0 { + return []SubscriptionSummary{}, nil + } + subIds := make([]int, 0, len(subs)) + for _, sub := range subs { + subIds = append(subIds, sub.Id) + } + var items []UserSubscriptionItem + if err := DB.Where("user_subscription_id IN ?", subIds).Find(&items).Error; err != nil { + return nil, err + } + itemsMap := make(map[int][]UserSubscriptionItem, len(subIds)) + for _, it := range items { + itemsMap[it.UserSubscriptionId] = append(itemsMap[it.UserSubscriptionId], it) + } result := make([]SubscriptionSummary, 0, len(subs)) for _, sub := range subs { - var items []UserSubscriptionItem - if err := DB.Where("user_subscription_id = ?", sub.Id).Find(&items).Error; err != nil { - continue - } subCopy := sub - result = append(result, SubscriptionSummary{Subscription: &subCopy, Items: items}) + result = append(result, SubscriptionSummary{ + Subscription: &subCopy, + Items: itemsMap[sub.Id], + }) } return result, nil } @@ -539,6 +565,30 @@ type SubscriptionPreConsumeResult struct { AmountUsedAfter int64 } +// SubscriptionPreConsumeRecord stores idempotent pre-consume operations per request. +type SubscriptionPreConsumeRecord struct { + Id int `json:"id"` + RequestId string `json:"request_id" gorm:"type:varchar(64);uniqueIndex"` + UserId int `json:"user_id" gorm:"index"` + UserSubscriptionItemId int `json:"user_subscription_item_id" gorm:"index"` + PreConsumed int64 `json:"pre_consumed" gorm:"type:bigint;not null;default:0"` + Status string `json:"status" gorm:"type:varchar(32);index"` // consumed/refunded + CreatedAt int64 `json:"created_at" gorm:"bigint"` + UpdatedAt int64 `json:"updated_at" gorm:"bigint;index"` +} + +func (r *SubscriptionPreConsumeRecord) BeforeCreate(tx *gorm.DB) error { + now := common.GetTimestamp() + r.CreatedAt = now + r.UpdatedAt = now + return nil +} + +func (r *SubscriptionPreConsumeRecord) BeforeUpdate(tx *gorm.DB) error { + r.UpdatedAt = common.GetTimestamp() + return nil +} + func maybeResetSubscriptionItemTx(tx *gorm.DB, item *UserSubscriptionItem, now int64) error { if tx == nil || item == nil { return errors.New("invalid reset args") @@ -587,20 +637,43 @@ func maybeResetSubscriptionItemTx(tx *gorm.DB, item *UserSubscriptionItem, now i // PreConsumeUserSubscription finds a valid active subscription item and increments amount_used. // quotaType=0 => consume quota units; quotaType=1 => consume request count (usually 1). -func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amount int64) (*SubscriptionPreConsumeResult, error) { +func PreConsumeUserSubscription(requestId string, userId int, modelName string, quotaType int, amount int64) (*SubscriptionPreConsumeResult, error) { if userId <= 0 { return nil, errors.New("invalid userId") } + if strings.TrimSpace(requestId) == "" { + return nil, errors.New("requestId is empty") + } if modelName == "" { return nil, errors.New("modelName is empty") } if amount <= 0 { return nil, errors.New("amount must be > 0") } - now := common.GetTimestamp() + now := GetDBTimestamp() returnValue := &SubscriptionPreConsumeResult{} err := DB.Transaction(func(tx *gorm.DB) error { + var existing SubscriptionPreConsumeRecord + if err := tx.Set("gorm:query_option", "FOR UPDATE"). + Where("request_id = ?", requestId).First(&existing).Error; err == nil { + if existing.Status == "refunded" { + return errors.New("subscription pre-consume already refunded") + } + var item UserSubscriptionItem + if err := tx.Where("id = ?", existing.UserSubscriptionItemId).First(&item).Error; err != nil { + return err + } + returnValue.UserSubscriptionId = item.UserSubscriptionId + returnValue.ItemId = item.Id + returnValue.QuotaType = item.QuotaType + returnValue.PreConsumed = existing.PreConsumed + returnValue.AmountTotal = item.AmountTotal + returnValue.AmountUsedBefore = item.AmountUsed + returnValue.AmountUsedAfter = item.AmountUsed + return nil + } + var item UserSubscriptionItem // lock item row; join to ensure subscription still active q := tx.Set("gorm:query_option", "FOR UPDATE"). @@ -609,17 +682,14 @@ func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amo Joins("JOIN user_subscriptions ON user_subscriptions.id = user_subscription_items.user_subscription_id"). Where("user_subscriptions.user_id = ? AND user_subscriptions.status = ? AND user_subscriptions.end_time > ?", userId, "active", now). Where("user_subscription_items.model_name = ? AND user_subscription_items.quota_type = ?", modelName, quotaType). - Order("user_subscriptions.end_time desc, user_subscriptions.id desc, user_subscription_items.id desc") + Order("user_subscriptions.end_time desc, user_subscriptions.id desc, user_subscription_items.id desc"). + Limit(1) if err := q.First(&item).Error; err != nil { return errors.New("no active subscription item for this model") } if err := maybeResetSubscriptionItemTx(tx, &item, now); err != nil { return err } - // reload item after potential reset - if err := tx.Set("gorm:query_option", "FOR UPDATE").Where("id = ?", item.Id).First(&item).Error; err != nil { - return err - } usedBefore := item.AmountUsed remain := item.AmountTotal - usedBefore if remain < amount { @@ -629,6 +699,16 @@ func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amo if err := tx.Save(&item).Error; err != nil { return err } + record := &SubscriptionPreConsumeRecord{ + RequestId: requestId, + UserId: userId, + UserSubscriptionItemId: item.Id, + PreConsumed: amount, + Status: "consumed", + } + if err := tx.Create(record).Error; err != nil { + return err + } returnValue.UserSubscriptionId = item.UserSubscriptionId returnValue.ItemId = item.Id returnValue.QuotaType = item.QuotaType @@ -644,6 +724,80 @@ func PreConsumeUserSubscription(userId int, modelName string, quotaType int, amo return returnValue, nil } +// RefundSubscriptionPreConsume is idempotent and refunds pre-consumed subscription quota by requestId. +func RefundSubscriptionPreConsume(requestId string) error { + if strings.TrimSpace(requestId) == "" { + return errors.New("requestId is empty") + } + return DB.Transaction(func(tx *gorm.DB) error { + var record SubscriptionPreConsumeRecord + if err := tx.Set("gorm:query_option", "FOR UPDATE"). + Where("request_id = ?", requestId).First(&record).Error; err != nil { + return err + } + if record.Status == "refunded" { + return nil + } + if record.PreConsumed <= 0 { + record.Status = "refunded" + return tx.Save(&record).Error + } + if err := PostConsumeUserSubscriptionDelta(record.UserSubscriptionItemId, -record.PreConsumed); err != nil { + return err + } + record.Status = "refunded" + return tx.Save(&record).Error + }) +} + +// ResetDueSubscriptionItems resets items whose next_reset_time has passed. +func ResetDueSubscriptionItems(limit int) (int, error) { + if limit <= 0 { + limit = 200 + } + now := GetDBTimestamp() + var items []UserSubscriptionItem + if err := DB.Where("next_reset_time > 0 AND next_reset_time <= ?", now). + Order("next_reset_time asc"). + Limit(limit). + Find(&items).Error; err != nil { + return 0, err + } + if len(items) == 0 { + return 0, nil + } + resetCount := 0 + for _, it := range items { + err := DB.Transaction(func(tx *gorm.DB) error { + var item UserSubscriptionItem + if err := tx.Set("gorm:query_option", "FOR UPDATE"). + Where("id = ? AND next_reset_time > 0 AND next_reset_time <= ?", it.Id, now). + First(&item).Error; err != nil { + return nil + } + if err := maybeResetSubscriptionItemTx(tx, &item, now); err != nil { + return err + } + resetCount++ + return nil + }) + if err != nil { + return resetCount, err + } + } + return resetCount, nil +} + +// CleanupSubscriptionPreConsumeRecords removes old idempotency records to keep table small. +func CleanupSubscriptionPreConsumeRecords(olderThanSeconds int64) (int64, error) { + if olderThanSeconds <= 0 { + olderThanSeconds = 7 * 24 * 3600 + } + cutoff := GetDBTimestamp() - olderThanSeconds + res := DB.Where("updated_at < ?", cutoff).Delete(&SubscriptionPreConsumeRecord{}) + return res.RowsAffected, res.Error +} + type SubscriptionPlanInfo struct { PlanId int PlanTitle string diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 7a3aad333..232e8c852 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -129,6 +129,8 @@ type RelayInfo struct { // SubscriptionPlanId / SubscriptionPlanTitle are used for logging/UI display. SubscriptionPlanId int SubscriptionPlanTitle string + // RequestId is used for idempotent pre-consume/refund + RequestId string // SubscriptionAmountTotal / SubscriptionAmountUsedAfterPreConsume are used to compute remaining in logs. SubscriptionAmountTotal int64 SubscriptionAmountUsedAfterPreConsume int64 @@ -418,9 +420,14 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { // firstResponseTime = time.Now() - 1 second + reqId := common.GetContextKeyString(c, common.RequestIdKey) + if reqId == "" { + reqId = common.GetTimeString() + common.GetRandomString(8) + } info := &RelayInfo{ Request: request, + RequestId: reqId, UserId: common.GetContextKeyInt(c, constant.ContextKeyUserId), UsingGroup: common.GetContextKeyString(c, constant.ContextKeyUsingGroup), UserGroup: common.GetContextKeyString(c, constant.ContextKeyUserGroup), diff --git a/service/billing.go b/service/billing.go index 84b329f72..6e001bc6d 100644 --- a/service/billing.go +++ b/service/billing.go @@ -56,7 +56,7 @@ func PreConsumeBilling(c *gin.Context, preConsumedQuota int, relayInfo *relaycom } } - res, err := model.PreConsumeUserSubscription(relayInfo.UserId, relayInfo.OriginModelName, quotaType, subConsume) + res, err := model.PreConsumeUserSubscription(relayInfo.RequestId, relayInfo.UserId, relayInfo.OriginModelName, quotaType, subConsume) if err != nil { // revert token pre-consume when subscription fails if preConsumedQuota > 0 && !relayInfo.IsPlayground { diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 4a5edc499..62b19bb75 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -3,6 +3,7 @@ package service import ( "fmt" "net/http" + "time" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/logger" @@ -30,7 +31,9 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { relayInfoCopy := *relayInfo if relayInfoCopy.BillingSource == BillingSourceSubscription { if needRefundSub { - _ = model.PostConsumeUserSubscriptionDelta(relayInfoCopy.SubscriptionItemId, -relayInfoCopy.SubscriptionPreConsumed) + refundWithRetry(func() error { + return model.RefundSubscriptionPreConsume(relayInfoCopy.RequestId) + }) } // refund token quota only if needRefundToken && !relayInfoCopy.IsPlayground { @@ -49,6 +52,21 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { }) } +func refundWithRetry(fn func() error) { + if fn == nil { + return + } + const maxAttempts = 3 + for i := 0; i < maxAttempts; i++ { + if err := fn(); err == nil { + return + } + if i < maxAttempts-1 { + time.Sleep(time.Duration(200*(i+1)) * time.Millisecond) + } + } +} + // PreConsumeQuota checks if the user has enough quota to pre-consume. // It returns the pre-consumed quota if successful, or an error if not. func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError { diff --git a/service/subscription_reset_task.go b/service/subscription_reset_task.go new file mode 100644 index 000000000..630d91ef9 --- /dev/null +++ b/service/subscription_reset_task.go @@ -0,0 +1,78 @@ +package service + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/model" + + "github.com/bytedance/gopkg/util/gopool" +) + +const ( + subscriptionResetTickInterval = 1 * time.Minute + subscriptionResetBatchSize = 300 + subscriptionCleanupInterval = 30 * time.Minute +) + +var ( + subscriptionResetOnce sync.Once + subscriptionResetRunning atomic.Bool + subscriptionCleanupLast atomic.Int64 +) + +func StartSubscriptionQuotaResetTask() { + subscriptionResetOnce.Do(func() { + if !common.IsMasterNode { + return + } + gopool.Go(func() { + logger.LogInfo(context.Background(), fmt.Sprintf("subscription quota reset task started: tick=%s", subscriptionResetTickInterval)) + ticker := time.NewTicker(subscriptionResetTickInterval) + defer ticker.Stop() + + runSubscriptionQuotaResetOnce() + for range ticker.C { + runSubscriptionQuotaResetOnce() + } + }) + }) +} + +func runSubscriptionQuotaResetOnce() { + if !subscriptionResetRunning.CompareAndSwap(false, true) { + return + } + defer subscriptionResetRunning.Store(false) + + ctx := context.Background() + totalReset := 0 + for { + n, err := model.ResetDueSubscriptionItems(subscriptionResetBatchSize) + if err != nil { + logger.LogWarn(ctx, fmt.Sprintf("subscription quota reset task failed: %v", err)) + return + } + if n == 0 { + break + } + totalReset += n + if n < subscriptionResetBatchSize { + break + } + } + lastCleanup := time.Unix(subscriptionCleanupLast.Load(), 0) + if time.Since(lastCleanup) >= subscriptionCleanupInterval { + if _, err := model.CleanupSubscriptionPreConsumeRecords(7 * 24 * 3600); err == nil { + subscriptionCleanupLast.Store(time.Now().Unix()) + } + } + if totalReset > 0 && common.DebugEnabled { + logger.LogDebug(ctx, "subscription quota reset: reset_count=%d", totalReset) + } +}