feat(subscription): cache plan lookups and stabilize pre-consume

Introduce hybrid caches for subscription plans, items, and plan info with explicit
invalidation on admin updates. Streamline pre-consume transactions to reduce
redundant queries while preserving idempotency and reset logic.
This commit is contained in:
t0ng7u
2026-01-31 01:11:32 +08:00
parent ffebb35499
commit 41489fc32a
2 changed files with 293 additions and 28 deletions

View File

@@ -3,10 +3,14 @@ package model
import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/samber/hot"
"gorm.io/gorm"
)
@@ -33,6 +37,152 @@ var (
ErrSubscriptionOrderStatusInvalid = errors.New("subscription order status invalid")
)
const (
subscriptionPlanCacheNamespace = "new-api:subscription_plan:v1"
subscriptionPlanItemsCacheNamespace = "new-api:subscription_plan_items:v1"
subscriptionPlanInfoCacheNamespace = "new-api:subscription_plan_info:v1"
)
var (
subscriptionPlanCacheOnce sync.Once
subscriptionPlanItemsCacheOnce sync.Once
subscriptionPlanInfoCacheOnce sync.Once
subscriptionPlanCache *cachex.HybridCache[SubscriptionPlan]
subscriptionPlanItemsCache *cachex.HybridCache[[]SubscriptionPlanItem]
subscriptionPlanInfoCache *cachex.HybridCache[SubscriptionPlanInfo]
)
func subscriptionPlanCacheTTL() time.Duration {
ttlSeconds := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_CACHE_TTL", 300)
if ttlSeconds <= 0 {
ttlSeconds = 300
}
return time.Duration(ttlSeconds) * time.Second
}
func subscriptionPlanItemsCacheTTL() time.Duration {
ttlSeconds := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_ITEMS_CACHE_TTL", 300)
if ttlSeconds <= 0 {
ttlSeconds = 300
}
return time.Duration(ttlSeconds) * time.Second
}
func subscriptionPlanInfoCacheTTL() time.Duration {
ttlSeconds := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_INFO_CACHE_TTL", 120)
if ttlSeconds <= 0 {
ttlSeconds = 120
}
return time.Duration(ttlSeconds) * time.Second
}
func subscriptionPlanCacheCapacity() int {
capacity := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_CACHE_CAP", 5000)
if capacity <= 0 {
capacity = 5000
}
return capacity
}
func subscriptionPlanItemsCacheCapacity() int {
capacity := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_ITEMS_CACHE_CAP", 10000)
if capacity <= 0 {
capacity = 10000
}
return capacity
}
func subscriptionPlanInfoCacheCapacity() int {
capacity := common.GetEnvOrDefault("SUBSCRIPTION_PLAN_INFO_CACHE_CAP", 10000)
if capacity <= 0 {
capacity = 10000
}
return capacity
}
func getSubscriptionPlanCache() *cachex.HybridCache[SubscriptionPlan] {
subscriptionPlanCacheOnce.Do(func() {
ttl := subscriptionPlanCacheTTL()
subscriptionPlanCache = cachex.NewHybridCache[SubscriptionPlan](cachex.HybridCacheConfig[SubscriptionPlan]{
Namespace: cachex.Namespace(subscriptionPlanCacheNamespace),
Redis: common.RDB,
RedisEnabled: func() bool {
return common.RedisEnabled && common.RDB != nil
},
RedisCodec: cachex.JSONCodec[SubscriptionPlan]{},
Memory: func() *hot.HotCache[string, SubscriptionPlan] {
return hot.NewHotCache[string, SubscriptionPlan](hot.LRU, subscriptionPlanCacheCapacity()).
WithTTL(ttl).
WithJanitor().
Build()
},
})
})
return subscriptionPlanCache
}
func getSubscriptionPlanItemsCache() *cachex.HybridCache[[]SubscriptionPlanItem] {
subscriptionPlanItemsCacheOnce.Do(func() {
ttl := subscriptionPlanItemsCacheTTL()
subscriptionPlanItemsCache = cachex.NewHybridCache[[]SubscriptionPlanItem](cachex.HybridCacheConfig[[]SubscriptionPlanItem]{
Namespace: cachex.Namespace(subscriptionPlanItemsCacheNamespace),
Redis: common.RDB,
RedisEnabled: func() bool {
return common.RedisEnabled && common.RDB != nil
},
RedisCodec: cachex.JSONCodec[[]SubscriptionPlanItem]{},
Memory: func() *hot.HotCache[string, []SubscriptionPlanItem] {
return hot.NewHotCache[string, []SubscriptionPlanItem](hot.LRU, subscriptionPlanItemsCacheCapacity()).
WithTTL(ttl).
WithJanitor().
Build()
},
})
})
return subscriptionPlanItemsCache
}
func getSubscriptionPlanInfoCache() *cachex.HybridCache[SubscriptionPlanInfo] {
subscriptionPlanInfoCacheOnce.Do(func() {
ttl := subscriptionPlanInfoCacheTTL()
subscriptionPlanInfoCache = cachex.NewHybridCache[SubscriptionPlanInfo](cachex.HybridCacheConfig[SubscriptionPlanInfo]{
Namespace: cachex.Namespace(subscriptionPlanInfoCacheNamespace),
Redis: common.RDB,
RedisEnabled: func() bool {
return common.RedisEnabled && common.RDB != nil
},
RedisCodec: cachex.JSONCodec[SubscriptionPlanInfo]{},
Memory: func() *hot.HotCache[string, SubscriptionPlanInfo] {
return hot.NewHotCache[string, SubscriptionPlanInfo](hot.LRU, subscriptionPlanInfoCacheCapacity()).
WithTTL(ttl).
WithJanitor().
Build()
},
})
})
return subscriptionPlanInfoCache
}
func subscriptionPlanCacheKey(id int) string {
if id <= 0 {
return ""
}
return strconv.Itoa(id)
}
func InvalidateSubscriptionPlanCache(planId int) {
if planId <= 0 {
return
}
cache := getSubscriptionPlanCache()
_, _ = cache.DeleteMany([]string{subscriptionPlanCacheKey(planId)})
itemsCache := getSubscriptionPlanItemsCache()
_, _ = itemsCache.DeleteMany([]string{subscriptionPlanCacheKey(planId)})
infoCache := getSubscriptionPlanInfoCache()
_ = infoCache.Purge()
}
// Subscription plan
type SubscriptionPlan struct {
Id int `json:"id"`
@@ -245,13 +395,28 @@ func calcNextResetTime(base time.Time, plan *SubscriptionPlan, endUnix int64) in
}
func GetSubscriptionPlanById(id int) (*SubscriptionPlan, error) {
return getSubscriptionPlanByIdTx(nil, id)
}
func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
if id <= 0 {
return nil, errors.New("invalid plan id")
}
key := subscriptionPlanCacheKey(id)
if key != "" {
if cached, found, err := getSubscriptionPlanCache().Get(key); err == nil && found {
return &cached, nil
}
}
var plan SubscriptionPlan
if err := DB.Where("id = ?", id).First(&plan).Error; err != nil {
query := DB
if tx != nil {
query = tx
}
if err := query.Where("id = ?", id).First(&plan).Error; err != nil {
return nil, err
}
_ = getSubscriptionPlanCache().SetWithTTL(key, plan, subscriptionPlanCacheTTL())
return &plan, nil
}
@@ -259,10 +424,17 @@ func GetSubscriptionPlanItems(planId int) ([]SubscriptionPlanItem, error) {
if planId <= 0 {
return nil, errors.New("invalid plan id")
}
key := subscriptionPlanCacheKey(planId)
if key != "" {
if cached, found, err := getSubscriptionPlanItemsCache().Get(key); err == nil && found {
return cached, nil
}
}
var items []SubscriptionPlanItem
if err := DB.Where("plan_id = ?", planId).Find(&items).Error; err != nil {
return nil, err
}
_ = getSubscriptionPlanItemsCache().SetWithTTL(key, items, subscriptionPlanItemsCacheTTL())
return items, nil
}
@@ -600,10 +772,20 @@ func maybeResetSubscriptionItemTx(tx *gorm.DB, item *UserSubscriptionItem, now i
if err := tx.Where("id = ?", item.UserSubscriptionId).First(&sub).Error; err != nil {
return err
}
var plan SubscriptionPlan
if err := tx.Where("id = ?", sub.PlanId).First(&plan).Error; err != nil {
plan, err := getSubscriptionPlanByIdTx(tx, sub.PlanId)
if err != nil {
return err
}
return maybeResetSubscriptionItemWithPlanTx(tx, item, &sub, plan, now)
}
func maybeResetSubscriptionItemWithPlanTx(tx *gorm.DB, item *UserSubscriptionItem, sub *UserSubscription, plan *SubscriptionPlan, now int64) error {
if tx == nil || item == nil || sub == nil || plan == nil {
return errors.New("invalid reset args")
}
if item.NextResetTime > 0 && item.NextResetTime > now {
return nil
}
if normalizeResetPeriod(plan.QuotaResetPeriod) == SubscriptionResetNever {
return nil
}
@@ -613,12 +795,12 @@ func maybeResetSubscriptionItemTx(tx *gorm.DB, item *UserSubscriptionItem, now i
baseUnix = sub.StartTime
}
base := time.Unix(baseUnix, 0)
next := calcNextResetTime(base, &plan, sub.EndTime)
next := calcNextResetTime(base, plan, sub.EndTime)
advanced := false
for next > 0 && next <= now {
advanced = true
base = time.Unix(next, 0)
next = calcNextResetTime(base, &plan, sub.EndTime)
next = calcNextResetTime(base, plan, sub.EndTime)
}
if !advanced {
// keep next reset time in sync if missing
@@ -653,10 +835,10 @@ func PreConsumeUserSubscription(requestId string, userId int, modelName string,
now := GetDBTimestamp()
returnValue := &SubscriptionPreConsumeResult{}
err := DB.Transaction(func(tx *gorm.DB) error {
var existing SubscriptionPreConsumeRecord
if err := tx.Set("gorm:query_option", "FOR UPDATE").
Where("request_id = ?", requestId).First(&existing).Error; err == nil {
if err := tx.Where("request_id = ?", requestId).First(&existing).Error; err == nil {
if existing.Status == "refunded" {
return errors.New("subscription pre-consume already refunded")
}
@@ -674,20 +856,35 @@ func PreConsumeUserSubscription(requestId string, userId int, modelName string,
return nil
}
var item UserSubscriptionItem
// lock item row; join to ensure subscription still active
q := tx.Set("gorm:query_option", "FOR UPDATE").
Table("user_subscription_items").
Select("user_subscription_items.*").
Joins("JOIN user_subscriptions ON user_subscriptions.id = user_subscription_items.user_subscription_id").
Where("user_subscriptions.user_id = ? AND user_subscriptions.status = ? AND user_subscriptions.end_time > ?", userId, "active", now).
Where("user_subscription_items.model_name = ? AND user_subscription_items.quota_type = ?", modelName, quotaType).
Order("user_subscriptions.end_time desc, user_subscriptions.id desc, user_subscription_items.id desc").
Limit(1)
if err := q.First(&item).Error; err != nil {
var activeSub UserSubscription
if err := tx.Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
Order("end_time desc, id desc").
First(&activeSub).Error; err != nil {
return errors.New("no active subscription item for this model")
}
if err := maybeResetSubscriptionItemTx(tx, &item, now); err != nil {
var candidate UserSubscriptionItem
if err := tx.Where("user_subscription_id = ? AND model_name = ? AND quota_type = ?", activeSub.Id, modelName, quotaType).
Order("id desc").
First(&candidate).Error; err != nil {
return errors.New("no active subscription item for this model")
}
var item UserSubscriptionItem
if err := tx.Set("gorm:query_option", "FOR UPDATE").
Where("id = ?", candidate.Id).
First(&item).Error; err != nil {
return errors.New("no active subscription item for this model")
}
var sub UserSubscription
if err := tx.Where("id = ? AND user_id = ? AND status = ? AND end_time > ?", item.UserSubscriptionId, userId, "active", now).
First(&sub).Error; err != nil {
return errors.New("no active subscription item for this model")
}
plan, err := getSubscriptionPlanByIdTx(tx, sub.PlanId)
if err != nil {
return err
}
if err := maybeResetSubscriptionItemWithPlanTx(tx, &item, &sub, plan, now); err != nil {
return err
}
usedBefore := item.AmountUsed
@@ -695,10 +892,6 @@ func PreConsumeUserSubscription(requestId string, userId int, modelName string,
if remain < amount {
return fmt.Errorf("subscription quota insufficient, remain=%d need=%d", remain, amount)
}
item.AmountUsed += amount
if err := tx.Save(&item).Error; err != nil {
return err
}
record := &SubscriptionPreConsumeRecord{
RequestId: requestId,
UserId: userId,
@@ -707,6 +900,24 @@ func PreConsumeUserSubscription(requestId string, userId int, modelName string,
Status: "consumed",
}
if err := tx.Create(record).Error; err != nil {
var dup SubscriptionPreConsumeRecord
if err2 := tx.Where("request_id = ?", requestId).First(&dup).Error; err2 == nil {
if dup.Status == "refunded" {
return errors.New("subscription pre-consume already refunded")
}
returnValue.UserSubscriptionId = item.UserSubscriptionId
returnValue.ItemId = item.Id
returnValue.QuotaType = item.QuotaType
returnValue.PreConsumed = dup.PreConsumed
returnValue.AmountTotal = item.AmountTotal
returnValue.AmountUsedBefore = item.AmountUsed
returnValue.AmountUsedAfter = item.AmountUsed
return nil
}
return err
}
item.AmountUsed += amount
if err := tx.Save(&item).Error; err != nil {
return err
}
returnValue.UserSubscriptionId = item.UserSubscriptionId
@@ -766,8 +977,53 @@ func ResetDueSubscriptionItems(limit int) (int, error) {
if len(items) == 0 {
return 0, nil
}
subIds := make([]int, 0, len(items))
subIdSet := make(map[int]struct{}, len(items))
for _, it := range items {
if it.UserSubscriptionId <= 0 {
continue
}
if _, exists := subIdSet[it.UserSubscriptionId]; exists {
continue
}
subIdSet[it.UserSubscriptionId] = struct{}{}
subIds = append(subIds, it.UserSubscriptionId)
}
subById := make(map[int]*UserSubscription, len(subIds))
if len(subIds) > 0 {
var subs []UserSubscription
if err := DB.Where("id IN ?", subIds).Find(&subs).Error; err != nil {
return 0, err
}
for i := range subs {
sub := subs[i]
subById[sub.Id] = &sub
}
}
planById := make(map[int]*SubscriptionPlan, len(subById))
for _, sub := range subById {
if sub == nil || sub.PlanId <= 0 {
continue
}
if _, exists := planById[sub.PlanId]; exists {
continue
}
plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
if err != nil {
return 0, err
}
planById[sub.PlanId] = plan
}
resetCount := 0
for _, it := range items {
sub := subById[it.UserSubscriptionId]
if sub == nil {
continue
}
plan := planById[sub.PlanId]
if plan == nil {
continue
}
err := DB.Transaction(func(tx *gorm.DB) error {
var item UserSubscriptionItem
if err := tx.Set("gorm:query_option", "FOR UPDATE").
@@ -775,7 +1031,7 @@ func ResetDueSubscriptionItems(limit int) (int, error) {
First(&item).Error; err != nil {
return nil
}
if err := maybeResetSubscriptionItemTx(tx, &item, now); err != nil {
if err := maybeResetSubscriptionItemWithPlanTx(tx, &item, sub, plan, now); err != nil {
return err
}
resetCount++
@@ -807,18 +1063,24 @@ func GetSubscriptionPlanInfoByUserSubscriptionId(userSubscriptionId int) (*Subsc
if userSubscriptionId <= 0 {
return nil, errors.New("invalid userSubscriptionId")
}
cacheKey := fmt.Sprintf("sub:%d", userSubscriptionId)
if cached, found, err := getSubscriptionPlanInfoCache().Get(cacheKey); err == nil && found {
return &cached, nil
}
var sub UserSubscription
if err := DB.Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
return nil, err
}
var plan SubscriptionPlan
if err := DB.Where("id = ?", sub.PlanId).First(&plan).Error; err != nil {
plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
if err != nil {
return nil, err
}
return &SubscriptionPlanInfo{
info := &SubscriptionPlanInfo{
PlanId: sub.PlanId,
PlanTitle: plan.Title,
}, nil
}
_ = getSubscriptionPlanInfoCache().SetWithTTL(cacheKey, *info, subscriptionPlanInfoCacheTTL())
return info, nil
}
func GetSubscriptionPlanInfoBySubscriptionItemId(itemId int) (*SubscriptionPlanInfo, error) {