Files
new-api/model/subscription.go
t0ng7u cf67af3b14 feat: Add subscription limits and UI tags consistency
Add per-plan purchase limits with backend enforcement and UI disable states.
Expose limit configuration in admin plan editor and show limits in plan tables/cards.
Refine subscription UI tags with unified badge style and streamlined “My Subscriptions” layout.
2026-01-31 15:02:03 +08:00

1145 lines
35 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package model
import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/pkg/cachex"
"github.com/samber/hot"
"gorm.io/gorm"
)
// Subscription duration units accepted in SubscriptionPlan.DurationUnit.
// "custom" takes its length from SubscriptionPlan.CustomSeconds instead of
// DurationValue.
const (
	SubscriptionDurationYear   = "year"
	SubscriptionDurationMonth  = "month"
	SubscriptionDurationDay    = "day"
	SubscriptionDurationHour   = "hour"
	SubscriptionDurationCustom = "custom"
)

// Quota reset periods accepted in SubscriptionPlan.QuotaResetPeriod. Any
// unrecognised value is normalised to "never" by NormalizeResetPeriod.
const (
	SubscriptionResetNever   = "never"
	SubscriptionResetDaily   = "daily"
	SubscriptionResetWeekly  = "weekly"
	SubscriptionResetMonthly = "monthly"
	SubscriptionResetCustom  = "custom"
)

// Sentinel errors returned by the subscription-order state transitions.
var (
	ErrSubscriptionOrderNotFound      = errors.New("subscription order not found")
	ErrSubscriptionOrderStatusInvalid = errors.New("subscription order status invalid")
)

// Cache namespaces; the trailing version segment lets the encoded layout
// change without colliding with entries written by older builds.
const (
	subscriptionPlanCacheNamespace      = "new-api:subscription_plan:v1"
	subscriptionPlanItemsCacheNamespace = "new-api:subscription_plan_items:v1"
	subscriptionPlanInfoCacheNamespace  = "new-api:subscription_plan_info:v1"
)
// Lazily constructed caches. Each sync.Once guards the construction of the
// cache pointer declared beside it.
var (
	// Plan rows keyed by plan id.
	subscriptionPlanCacheOnce sync.Once
	subscriptionPlanCache     *cachex.HybridCache[SubscriptionPlan]

	// Plan item lists keyed by plan id.
	subscriptionPlanItemsCacheOnce sync.Once
	subscriptionPlanItemsCache     *cachex.HybridCache[[]SubscriptionPlanItem]

	// Plan id/title lookups keyed by user-subscription id.
	subscriptionPlanInfoCacheOnce sync.Once
	subscriptionPlanInfoCache     *cachex.HybridCache[SubscriptionPlanInfo]
)
// subscriptionEnvPositiveInt reads an integer setting through
// common.GetEnvOrDefault and falls back to def whenever the configured value
// is zero or negative.
func subscriptionEnvPositiveInt(key string, def int) int {
	if v := common.GetEnvOrDefault(key, def); v > 0 {
		return v
	}
	return def
}

// subscriptionPlanCacheTTL is the lifetime of cached plan rows
// (env SUBSCRIPTION_PLAN_CACHE_TTL, in seconds, default 300).
func subscriptionPlanCacheTTL() time.Duration {
	return time.Duration(subscriptionEnvPositiveInt("SUBSCRIPTION_PLAN_CACHE_TTL", 300)) * time.Second
}

// subscriptionPlanItemsCacheTTL is the lifetime of cached plan item lists
// (env SUBSCRIPTION_PLAN_ITEMS_CACHE_TTL, in seconds, default 300).
func subscriptionPlanItemsCacheTTL() time.Duration {
	return time.Duration(subscriptionEnvPositiveInt("SUBSCRIPTION_PLAN_ITEMS_CACHE_TTL", 300)) * time.Second
}

// subscriptionPlanInfoCacheTTL is the lifetime of cached plan info lookups
// (env SUBSCRIPTION_PLAN_INFO_CACHE_TTL, in seconds, default 120).
func subscriptionPlanInfoCacheTTL() time.Duration {
	return time.Duration(subscriptionEnvPositiveInt("SUBSCRIPTION_PLAN_INFO_CACHE_TTL", 120)) * time.Second
}

// subscriptionPlanCacheCapacity is the in-memory entry limit for plan rows
// (env SUBSCRIPTION_PLAN_CACHE_CAP, default 5000).
func subscriptionPlanCacheCapacity() int {
	return subscriptionEnvPositiveInt("SUBSCRIPTION_PLAN_CACHE_CAP", 5000)
}

// subscriptionPlanItemsCacheCapacity is the in-memory entry limit for plan
// item lists (env SUBSCRIPTION_PLAN_ITEMS_CACHE_CAP, default 10000).
func subscriptionPlanItemsCacheCapacity() int {
	return subscriptionEnvPositiveInt("SUBSCRIPTION_PLAN_ITEMS_CACHE_CAP", 10000)
}

// subscriptionPlanInfoCacheCapacity is the in-memory entry limit for plan
// info lookups (env SUBSCRIPTION_PLAN_INFO_CACHE_CAP, default 10000).
func subscriptionPlanInfoCacheCapacity() int {
	return subscriptionEnvPositiveInt("SUBSCRIPTION_PLAN_INFO_CACHE_CAP", 10000)
}
// newSubscriptionHybridCache builds one of the subscription caches. The three
// caches share identical wiring (Redis tier gated on common.RedisEnabled plus
// an in-memory hot LRU with TTL and janitor) and differ only in value type,
// namespace, TTL and in-memory capacity. capacity is called lazily, when
// cachex asks for the memory tier, as in the original per-cache wiring.
//
// NOTE(review): Redis is captured from common.RDB at construction time while
// RedisEnabled re-checks common.RDB on every call; if Redis is initialised
// after the first cache access the captured client stays nil — confirm how
// cachex handles that case.
func newSubscriptionHybridCache[T any](namespace string, ttl time.Duration, capacity func() int) *cachex.HybridCache[T] {
	return cachex.NewHybridCache[T](cachex.HybridCacheConfig[T]{
		Namespace: cachex.Namespace(namespace),
		Redis:     common.RDB,
		RedisEnabled: func() bool {
			return common.RedisEnabled && common.RDB != nil
		},
		RedisCodec: cachex.JSONCodec[T]{},
		Memory: func() *hot.HotCache[string, T] {
			return hot.NewHotCache[string, T](hot.LRU, capacity()).
				WithTTL(ttl).
				WithJanitor().
				Build()
		},
	})
}

// getSubscriptionPlanCache returns the lazily built cache of plan rows keyed
// by plan id.
func getSubscriptionPlanCache() *cachex.HybridCache[SubscriptionPlan] {
	subscriptionPlanCacheOnce.Do(func() {
		subscriptionPlanCache = newSubscriptionHybridCache[SubscriptionPlan](
			subscriptionPlanCacheNamespace, subscriptionPlanCacheTTL(), subscriptionPlanCacheCapacity)
	})
	return subscriptionPlanCache
}

// getSubscriptionPlanItemsCache returns the lazily built cache of plan item
// lists keyed by plan id.
func getSubscriptionPlanItemsCache() *cachex.HybridCache[[]SubscriptionPlanItem] {
	subscriptionPlanItemsCacheOnce.Do(func() {
		subscriptionPlanItemsCache = newSubscriptionHybridCache[[]SubscriptionPlanItem](
			subscriptionPlanItemsCacheNamespace, subscriptionPlanItemsCacheTTL(), subscriptionPlanItemsCacheCapacity)
	})
	return subscriptionPlanItemsCache
}

// getSubscriptionPlanInfoCache returns the lazily built cache of plan info
// lookups keyed by "sub:<userSubscriptionId>".
func getSubscriptionPlanInfoCache() *cachex.HybridCache[SubscriptionPlanInfo] {
	subscriptionPlanInfoCacheOnce.Do(func() {
		subscriptionPlanInfoCache = newSubscriptionHybridCache[SubscriptionPlanInfo](
			subscriptionPlanInfoCacheNamespace, subscriptionPlanInfoCacheTTL(), subscriptionPlanInfoCacheCapacity)
	})
	return subscriptionPlanInfoCache
}
// subscriptionPlanCacheKey maps a plan id to its cache key (the decimal id).
// Non-positive ids have no key and map to the empty string.
func subscriptionPlanCacheKey(id int) string {
	key := ""
	if id > 0 {
		key = strconv.Itoa(id)
	}
	return key
}
func InvalidateSubscriptionPlanCache(planId int) {
if planId <= 0 {
return
}
cache := getSubscriptionPlanCache()
_, _ = cache.DeleteMany([]string{subscriptionPlanCacheKey(planId)})
itemsCache := getSubscriptionPlanItemsCache()
_, _ = itemsCache.DeleteMany([]string{subscriptionPlanCacheKey(planId)})
infoCache := getSubscriptionPlanInfoCache()
_ = infoCache.Purge()
}
// SubscriptionPlan is a purchasable subscription template. Purchasing it (or
// an admin bind) creates a UserSubscription whose items are copied from the
// plan's SubscriptionPlanItem rows (see CreateUserSubscriptionFromPlanTx).
type SubscriptionPlan struct {
	Id       int    `json:"id"`
	Title    string `json:"title" gorm:"type:varchar(128);not null"`
	Subtitle string `json:"subtitle" gorm:"type:varchar(255);default:''"`
	// Display money amount (follow existing code style: float64 for money)
	PriceAmount float64 `json:"price_amount" gorm:"type:double;not null;default:0"`
	Currency    string  `json:"currency" gorm:"type:varchar(8);not null;default:'USD'"`
	// DurationUnit is one of the SubscriptionDuration* constants. For every unit
	// except "custom" the length is DurationValue units (must be > 0); for
	// "custom" it is CustomSeconds.
	DurationUnit  string `json:"duration_unit" gorm:"type:varchar(16);not null;default:'month'"`
	DurationValue int    `json:"duration_value" gorm:"type:int;not null;default:1"`
	CustomSeconds int64  `json:"custom_seconds" gorm:"type:bigint;not null;default:0"`
	Enabled       bool   `json:"enabled" gorm:"default:true"`
	SortOrder     int    `json:"sort_order" gorm:"type:int;default:0"`
	// Payment-provider identifiers for this plan (Stripe price, Creem product).
	StripePriceId  string `json:"stripe_price_id" gorm:"type:varchar(128);default:''"`
	CreemProductId string `json:"creem_product_id" gorm:"type:varchar(128);default:''"`
	// Max purchases per user (0 = unlimited). The limit is checked against the
	// count of all UserSubscription rows for this user and plan, whatever their
	// status.
	MaxPurchasePerUser int `json:"max_purchase_per_user" gorm:"type:int;default:0"`
	// Quota reset period for plan items: one of the SubscriptionReset*
	// constants; unknown values behave as "never". QuotaResetCustomSeconds is
	// the period length when the period is "custom".
	QuotaResetPeriod        string `json:"quota_reset_period" gorm:"type:varchar(16);default:'never'"`
	QuotaResetCustomSeconds int64  `json:"quota_reset_custom_seconds" gorm:"type:bigint;default:0"`
	// Timestamps stamped from common.GetTimestamp() by the BeforeCreate /
	// BeforeUpdate hooks.
	// NOTE(review): the tag `gorm:"bigint"` lacks the `type:` key, so GORM most
	// likely ignores it — confirm the intended column type.
	CreatedAt int64 `json:"created_at" gorm:"bigint"`
	UpdatedAt int64 `json:"updated_at" gorm:"bigint"`
}
// BeforeCreate is a GORM hook that stamps CreatedAt and UpdatedAt with the
// same current timestamp.
func (p *SubscriptionPlan) BeforeCreate(tx *gorm.DB) error {
	ts := common.GetTimestamp()
	p.CreatedAt, p.UpdatedAt = ts, ts
	return nil
}

// BeforeUpdate is a GORM hook that refreshes UpdatedAt.
func (p *SubscriptionPlan) BeforeUpdate(tx *gorm.DB) error {
	ts := common.GetTimestamp()
	p.UpdatedAt = ts
	return nil
}
// SubscriptionPlanItem is one per-model allowance inside a plan. It is copied
// into a UserSubscriptionItem when a subscription is created from the plan.
type SubscriptionPlanItem struct {
	Id        int    `json:"id"`
	PlanId    int    `json:"plan_id" gorm:"index"`
	ModelName string `json:"model_name" gorm:"type:varchar(128);index"`
	// 0 = metered by quota amount, 1 = metered by request count.
	QuotaType int `json:"quota_type" gorm:"type:int;index"`
	// If quota_type=0 => amount in quota units; if quota_type=1 => request count.
	AmountTotal int64 `json:"amount_total" gorm:"type:bigint;not null;default:0"`
}
// SubscriptionOrder tracks a paid subscription purchase
// (payment -> webhook -> create UserSubscription).
// Status holds common.TopUpStatus* values: CompleteSubscriptionOrder moves a
// pending order to success, and ExpireSubscriptionOrder moves a pending order
// to expired. Presumably the checkout flow inserts it as pending — TODO confirm.
type SubscriptionOrder struct {
	Id     int     `json:"id"`
	UserId int     `json:"user_id" gorm:"index"`
	PlanId int     `json:"plan_id" gorm:"index"`
	Money  float64 `json:"money"`
	// TradeNo is the unique order reference used to look the order up.
	TradeNo       string `json:"trade_no" gorm:"unique;type:varchar(255);index"`
	PaymentMethod string `json:"payment_method" gorm:"type:varchar(50)"`
	Status        string `json:"status"`
	CreateTime    int64  `json:"create_time"`
	CompleteTime  int64  `json:"complete_time"`
	// ProviderPayload optionally stores the raw payload passed on completion.
	ProviderPayload string `json:"provider_payload" gorm:"type:text"`
}
// Insert persists a new order. CreateTime defaults to now when unset.
func (o *SubscriptionOrder) Insert() error {
	if o.CreateTime != 0 {
		return DB.Create(o).Error
	}
	o.CreateTime = common.GetTimestamp()
	return DB.Create(o).Error
}

// Update writes every field of the order back to the database.
func (o *SubscriptionOrder) Update() error {
	result := DB.Save(o)
	return result.Error
}
// GetSubscriptionOrderByTradeNo loads an order by its trade number. It
// returns nil for an empty trade number and on any lookup failure, so callers
// cannot tell "not found" apart from a database error.
func GetSubscriptionOrderByTradeNo(tradeNo string) *SubscriptionOrder {
	if tradeNo == "" {
		return nil
	}
	order := &SubscriptionOrder{}
	if DB.Where("trade_no = ?", tradeNo).First(order).Error != nil {
		return nil
	}
	return order
}
// UserSubscription is a concrete subscription held by a user, created from a
// plan snapshot. The composite index idx_user_sub_active on
// (user_id, status, end_time) serves the "active subscriptions of a user"
// queries.
type UserSubscription struct {
	Id     int `json:"id"`
	UserId int `json:"user_id" gorm:"index;index:idx_user_sub_active,priority:1"`
	PlanId int `json:"plan_id" gorm:"index"`
	// Start/end of the subscription window as unix timestamps; a subscription
	// is treated as active while status is "active" and end_time is in the future.
	StartTime int64  `json:"start_time" gorm:"bigint"`
	EndTime   int64  `json:"end_time" gorm:"bigint;index;index:idx_user_sub_active,priority:3"`
	Status    string `json:"status" gorm:"type:varchar(32);index;index:idx_user_sub_active,priority:2"` // active/expired/cancelled
	Source    string `json:"source" gorm:"type:varchar(32);default:'order'"`                            // order/admin
	CreatedAt int64  `json:"created_at" gorm:"bigint"`
	UpdatedAt int64  `json:"updated_at" gorm:"bigint"`
}
// BeforeCreate is a GORM hook that stamps CreatedAt and UpdatedAt with one
// shared timestamp, overriding any value set by the caller.
func (s *UserSubscription) BeforeCreate(tx *gorm.DB) error {
	ts := common.GetTimestamp()
	s.CreatedAt, s.UpdatedAt = ts, ts
	return nil
}

// BeforeUpdate is a GORM hook that refreshes UpdatedAt.
func (s *UserSubscription) BeforeUpdate(tx *gorm.DB) error {
	ts := common.GetTimestamp()
	s.UpdatedAt = ts
	return nil
}
// UserSubscriptionItem is a per-model allowance copied from a plan item into
// a user's subscription, together with its consumption counter. The index
// idx_sub_item_model_quota on (model_name, quota_type, user_subscription_id)
// serves the lookup of the item matching a request's model and quota type.
type UserSubscriptionItem struct {
	Id                 int    `json:"id"`
	UserSubscriptionId int    `json:"user_subscription_id" gorm:"index;index:idx_sub_item_model_quota,priority:3"`
	ModelName          string `json:"model_name" gorm:"type:varchar(128);index;index:idx_sub_item_model_quota,priority:1"`
	QuotaType          int    `json:"quota_type" gorm:"type:int;index;index:idx_sub_item_model_quota,priority:2"`
	// AmountTotal is the allowance per reset period; AmountUsed is what has
	// been consumed since the last reset.
	AmountTotal int64 `json:"amount_total" gorm:"type:bigint;not null;default:0"`
	AmountUsed  int64 `json:"amount_used" gorm:"type:bigint;not null;default:0"`
	// Reset bookkeeping as unix timestamps; both are 0 when the plan never
	// resets or no further reset fits before the subscription ends.
	LastResetTime int64 `json:"last_reset_time" gorm:"type:bigint;default:0"`
	NextResetTime int64 `json:"next_reset_time" gorm:"type:bigint;default:0;index"`
}

// SubscriptionSummary pairs a subscription with its items for API responses.
// Items is nil (not empty) when the subscription has no item rows.
type SubscriptionSummary struct {
	Subscription *UserSubscription     `json:"subscription"`
	Items        []UserSubscriptionItem `json:"items"`
}
// calcPlanEndTime returns the unix time at which a subscription started at
// start ends, according to the plan's duration settings.
//
// Year and month use time.AddDate, which normalises overflowing dates
// (e.g. Jan 31 + 1 month lands in early March). Day and hour add fixed
// 24h / 1h spans. "custom" adds CustomSeconds. DurationValue must be
// positive for every unit except "custom".
func calcPlanEndTime(start time.Time, plan *SubscriptionPlan) (int64, error) {
	if plan == nil {
		return 0, errors.New("plan is nil")
	}
	unit, n := plan.DurationUnit, plan.DurationValue
	if unit != SubscriptionDurationCustom && n <= 0 {
		return 0, errors.New("duration_value must be > 0")
	}

	var end time.Time
	switch unit {
	case SubscriptionDurationYear:
		end = start.AddDate(n, 0, 0)
	case SubscriptionDurationMonth:
		end = start.AddDate(0, n, 0)
	case SubscriptionDurationDay:
		end = start.Add(time.Duration(n) * 24 * time.Hour)
	case SubscriptionDurationHour:
		end = start.Add(time.Duration(n) * time.Hour)
	case SubscriptionDurationCustom:
		if plan.CustomSeconds <= 0 {
			return 0, errors.New("custom_seconds must be > 0")
		}
		end = start.Add(time.Duration(plan.CustomSeconds) * time.Second)
	default:
		return 0, fmt.Errorf("invalid duration_unit: %s", unit)
	}
	return end.Unix(), nil
}
// NormalizeResetPeriod trims period and returns it if it names a known reset
// period (daily/weekly/monthly/custom); anything else maps to "never".
func NormalizeResetPeriod(period string) string {
	p := strings.TrimSpace(period)
	switch p {
	case SubscriptionResetDaily, SubscriptionResetWeekly, SubscriptionResetMonthly, SubscriptionResetCustom:
		return p
	}
	return SubscriptionResetNever
}
// calcNextResetTime returns the unix time of the first quota reset after base
// for the plan's reset period, or 0 when there is none. It returns 0 when the
// plan is nil, the period is "never", a custom period has no positive length,
// or the reset would fall after endUnix (when endUnix > 0).
//
// Calendar periods align to midnight in base's location: daily to the next
// midnight, weekly to the next Monday, monthly to the 1st of the next month.
// "custom" is simply base + QuotaResetCustomSeconds.
func calcNextResetTime(base time.Time, plan *SubscriptionPlan, endUnix int64) int64 {
	if plan == nil {
		return 0
	}
	loc := base.Location()
	midnight := time.Date(base.Year(), base.Month(), base.Day(), 0, 0, 0, 0, loc)

	var next time.Time
	switch NormalizeResetPeriod(plan.QuotaResetPeriod) {
	case SubscriptionResetDaily:
		next = midnight.AddDate(0, 0, 1)
	case SubscriptionResetWeekly:
		// ISO numbering (Monday=1 ... Sunday=7), so Monday jumps a full week.
		isoWeekday := int(base.Weekday())
		if isoWeekday == 0 {
			isoWeekday = 7
		}
		next = midnight.AddDate(0, 0, 8-isoWeekday)
	case SubscriptionResetMonthly:
		next = time.Date(base.Year(), base.Month(), 1, 0, 0, 0, 0, loc).AddDate(0, 1, 0)
	case SubscriptionResetCustom:
		if plan.QuotaResetCustomSeconds <= 0 {
			return 0
		}
		next = base.Add(time.Duration(plan.QuotaResetCustomSeconds) * time.Second)
	default: // SubscriptionResetNever
		return 0
	}

	nextUnix := next.Unix()
	if endUnix > 0 && nextUnix > endUnix {
		return 0
	}
	return nextUnix
}
// GetSubscriptionPlanById loads a plan by id, consulting the plan cache first.
func GetSubscriptionPlanById(id int) (*SubscriptionPlan, error) {
	return getSubscriptionPlanByIdTx(nil, id)
}

// getSubscriptionPlanByIdTx is GetSubscriptionPlanById with an optional
// transaction for the database read (nil means DB). A cache hit is returned
// without touching the database, even inside a transaction. A database miss
// populates the cache.
func getSubscriptionPlanByIdTx(tx *gorm.DB, id int) (*SubscriptionPlan, error) {
	if id <= 0 {
		return nil, errors.New("invalid plan id")
	}
	cache := getSubscriptionPlanCache()
	key := subscriptionPlanCacheKey(id)
	if key != "" {
		cached, found, err := cache.Get(key)
		if err == nil && found {
			return &cached, nil
		}
	}

	db := tx
	if db == nil {
		db = DB
	}
	plan := SubscriptionPlan{}
	if err := db.Where("id = ?", id).First(&plan).Error; err != nil {
		return nil, err
	}
	// Best-effort cache fill; a failure only costs a future DB read.
	_ = cache.SetWithTTL(key, plan, subscriptionPlanCacheTTL())
	return &plan, nil
}
// GetSubscriptionPlanItems returns the items of a plan, consulting the items
// cache first. The result may come straight from the cache, so callers should
// treat it as read-only.
func GetSubscriptionPlanItems(planId int) ([]SubscriptionPlanItem, error) {
	if planId <= 0 {
		return nil, errors.New("invalid plan id")
	}
	cache := getSubscriptionPlanItemsCache()
	key := subscriptionPlanCacheKey(planId)
	if key != "" {
		cached, found, err := cache.Get(key)
		if err == nil && found {
			return cached, nil
		}
	}

	var items []SubscriptionPlanItem
	if err := DB.Where("plan_id = ?", planId).Find(&items).Error; err != nil {
		return nil, err
	}
	// Best-effort cache fill.
	_ = cache.SetWithTTL(key, items, subscriptionPlanItemsCacheTTL())
	return items, nil
}
// CountUserSubscriptionsByPlan reports how many subscriptions (in any status)
// the user holds or has held for the given plan.
func CountUserSubscriptionsByPlan(userId int, planId int) (int64, error) {
	if userId <= 0 || planId <= 0 {
		return 0, errors.New("invalid userId or planId")
	}
	var n int64
	err := DB.Model(&UserSubscription{}).
		Where("user_id = ? AND plan_id = ?", userId, planId).
		Count(&n).Error
	if err != nil {
		return 0, err
	}
	return n, nil
}
// CreateUserSubscriptionFromPlanTx creates, inside tx, an active subscription
// for userId from plan, and copies the plan's items onto it with zeroed usage.
//
// source is stored on the subscription ("order" and "admin" are used in this
// file). The window starts at the database's current time and ends per
// calcPlanEndTime. Items get reset bookkeeping when the plan has a reset
// period.
//
// Errors: invalid arguments, the per-user purchase limit being reached
// (MaxPurchasePerUser counts all of the user's subscriptions to the plan),
// an invalid plan duration, a plan without items, or database failures.
// Everything is validated before the first insert, so nothing is written
// when validation fails.
//
// NOTE(review): the purchase-limit count is not serialised against
// concurrent purchases of the same plan by the same user.
func CreateUserSubscriptionFromPlanTx(tx *gorm.DB, userId int, plan *SubscriptionPlan, source string) (*UserSubscription, error) {
	if tx == nil {
		return nil, errors.New("tx is nil")
	}
	if plan == nil || plan.Id <= 0 {
		return nil, errors.New("invalid plan")
	}
	if userId <= 0 {
		return nil, errors.New("invalid user id")
	}
	if plan.MaxPurchasePerUser > 0 {
		var count int64
		if err := tx.Model(&UserSubscription{}).
			Where("user_id = ? AND plan_id = ?", userId, plan.Id).
			Count(&count).Error; err != nil {
			return nil, err
		}
		if count >= int64(plan.MaxPurchasePerUser) {
			// "Purchase limit for this plan has been reached."
			return nil, errors.New("已达到该套餐购买上限")
		}
	}

	nowUnix := GetDBTimestamp()
	now := time.Unix(nowUnix, 0)
	endUnix, err := calcPlanEndTime(now, plan)
	if err != nil {
		return nil, err
	}

	// Load and validate the item template before writing anything.
	items, err := GetSubscriptionPlanItems(plan.Id)
	if err != nil {
		return nil, err
	}
	if len(items) == 0 {
		return nil, errors.New("plan has no items")
	}

	nextReset := calcNextResetTime(now, plan, endUnix)
	lastReset := int64(0)
	if nextReset > 0 {
		lastReset = nowUnix
	}

	// CreatedAt/UpdatedAt are stamped by UserSubscription.BeforeCreate.
	sub := &UserSubscription{
		UserId:    userId,
		PlanId:    plan.Id,
		StartTime: nowUnix,
		EndTime:   endUnix,
		Status:    "active",
		Source:    source,
	}
	if err := tx.Create(sub).Error; err != nil {
		return nil, err
	}

	userItems := make([]UserSubscriptionItem, 0, len(items))
	for _, it := range items {
		userItems = append(userItems, UserSubscriptionItem{
			UserSubscriptionId: sub.Id,
			ModelName:          it.ModelName,
			QuotaType:          it.QuotaType,
			AmountTotal:        it.AmountTotal,
			AmountUsed:         0,
			LastResetTime:      lastReset,
			NextResetTime:      nextReset,
		})
	}
	if err := tx.Create(&userItems).Error; err != nil {
		return nil, err
	}
	return sub, nil
}
// errSubscriptionOrderClaimLost is returned from inside the completion
// transaction when the conditional pending->success update matched no row,
// i.e. another caller changed the order's status after it was read.
var errSubscriptionOrderClaimLost = errors.New("subscription order status changed concurrently")

// CompleteSubscriptionOrder marks a pending order as paid and creates the
// corresponding UserSubscription from the plan. It is idempotent: an order
// that is already successful yields nil.
//
// The order is claimed with a conditional UPDATE (status must still be
// pending) before the subscription is created. Concurrent deliveries of the
// same payment notification therefore cannot both create a subscription; the
// GORM v1 "gorm:query_option" FOR UPDATE setting is not honoured by GORM v2.
// Disabled plans are still honoured, because the order was paid while it was
// pending.
//
// Errors: ErrSubscriptionOrderNotFound, ErrSubscriptionOrderStatusInvalid
// (order is neither pending nor successful), or any error from plan lookup /
// subscription creation (the whole transaction, including the claim, is then
// rolled back and the order stays pending).
func CompleteSubscriptionOrder(tradeNo string, providerPayload string) error {
	if tradeNo == "" {
		return errors.New("tradeNo is empty")
	}
	refCol := "`trade_no`"
	if common.UsingPostgreSQL {
		refCol = `"trade_no"`
	}
	var logUserId int
	var logPlanTitle string
	var logMoney float64
	var logPaymentMethod string
	err := DB.Transaction(func(tx *gorm.DB) error {
		var order SubscriptionOrder
		if err := tx.Where(refCol+" = ?", tradeNo).First(&order).Error; err != nil {
			return ErrSubscriptionOrderNotFound
		}
		if order.Status == common.TopUpStatusSuccess {
			return nil
		}
		if order.Status != common.TopUpStatusPending {
			return ErrSubscriptionOrderStatusInvalid
		}
		plan, err := getSubscriptionPlanByIdTx(tx, order.PlanId)
		if err != nil {
			return err
		}

		// Claim the order atomically; only one concurrent caller can win.
		updates := map[string]interface{}{
			"status":        common.TopUpStatusSuccess,
			"complete_time": common.GetTimestamp(),
		}
		if providerPayload != "" {
			updates["provider_payload"] = providerPayload
		}
		claim := tx.Model(&SubscriptionOrder{}).
			Where("id = ? AND status = ?", order.Id, common.TopUpStatusPending).
			Updates(updates)
		if claim.Error != nil {
			return claim.Error
		}
		if claim.RowsAffected == 0 {
			return errSubscriptionOrderClaimLost
		}

		if _, err := CreateUserSubscriptionFromPlanTx(tx, order.UserId, plan, "order"); err != nil {
			return err
		}
		if err := upsertSubscriptionTopUpTx(tx, &order); err != nil {
			return err
		}
		logUserId = order.UserId
		logPlanTitle = plan.Title
		logMoney = order.Money
		logPaymentMethod = order.PaymentMethod
		return nil
	})
	if errors.Is(err, errSubscriptionOrderClaimLost) {
		// Re-read outside the rolled-back transaction to see the winner's
		// committed state: a completed order keeps this call idempotent.
		if latest := GetSubscriptionOrderByTradeNo(tradeNo); latest != nil && latest.Status == common.TopUpStatusSuccess {
			return nil
		}
		return ErrSubscriptionOrderStatusInvalid
	}
	if err != nil {
		return err
	}
	if logUserId > 0 {
		msg := fmt.Sprintf("订阅购买成功,套餐: %s,支付金额: %.2f,支付方式: %s", logPlanTitle, logMoney, logPaymentMethod)
		RecordLog(logUserId, LogTypeTopup, msg)
	}
	return nil
}
// upsertSubscriptionTopUpTx mirrors a completed subscription order into the
// TopUp table (keyed by trade number) with status success and a zero quota
// Amount. An existing row gets its money, completion time and status
// refreshed; its payment method and create time are filled only if blank.
func upsertSubscriptionTopUpTx(tx *gorm.DB, order *SubscriptionOrder) error {
	if tx == nil || order == nil {
		return errors.New("invalid subscription order")
	}
	now := common.GetTimestamp()

	var topup TopUp
	lookupErr := tx.Where("trade_no = ?", order.TradeNo).First(&topup).Error
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		topup = TopUp{
			UserId:        order.UserId,
			Amount:        0,
			Money:         order.Money,
			TradeNo:       order.TradeNo,
			PaymentMethod: order.PaymentMethod,
			CreateTime:    order.CreateTime,
			CompleteTime:  now,
			Status:        common.TopUpStatusSuccess,
		}
		return tx.Create(&topup).Error
	case lookupErr != nil:
		return lookupErr
	}

	topup.Money = order.Money
	if topup.PaymentMethod == "" {
		topup.PaymentMethod = order.PaymentMethod
	}
	if topup.CreateTime == 0 {
		topup.CreateTime = order.CreateTime
	}
	topup.CompleteTime = now
	topup.Status = common.TopUpStatusSuccess
	return tx.Save(&topup).Error
}
// ExpireSubscriptionOrder moves a pending order to expired. Orders in any
// other status are left untouched and yield nil; an unknown trade number
// yields ErrSubscriptionOrderNotFound.
//
// The transition is a single conditional UPDATE (status must still be
// pending), so it cannot overwrite a concurrent completion. The previous
// read-then-save relied on the GORM v1 "gorm:query_option" FOR UPDATE
// setting, which GORM v2 ignores.
func ExpireSubscriptionOrder(tradeNo string) error {
	if tradeNo == "" {
		return errors.New("tradeNo is empty")
	}
	refCol := "`trade_no`"
	if common.UsingPostgreSQL {
		refCol = `"trade_no"`
	}
	res := DB.Model(&SubscriptionOrder{}).
		Where(refCol+" = ? AND status = ?", tradeNo, common.TopUpStatusPending).
		Updates(map[string]interface{}{
			"status":        common.TopUpStatusExpired,
			"complete_time": common.GetTimestamp(),
		})
	if res.Error != nil {
		return res.Error
	}
	if res.RowsAffected > 0 {
		return nil
	}
	// Nothing updated: either the order does not exist or it is no longer pending.
	if GetSubscriptionOrderByTradeNo(tradeNo) == nil {
		return ErrSubscriptionOrderNotFound
	}
	return nil
}
// AdminBindSubscription grants a user a subscription to a plan without
// payment. The subscription is recorded with source "admin".
//
// NOTE(review): sourceNote is accepted but currently unused.
func AdminBindSubscription(userId int, planId int, sourceNote string) error {
	if userId <= 0 || planId <= 0 {
		return errors.New("invalid userId or planId")
	}
	plan, err := GetSubscriptionPlanById(planId)
	if err != nil {
		return err
	}
	bind := func(tx *gorm.DB) error {
		_, txErr := CreateUserSubscriptionFromPlanTx(tx, userId, plan, "admin")
		return txErr
	}
	return DB.Transaction(bind)
}
// GetAllActiveUserSubscriptions returns a user's subscriptions whose status
// is "active" and whose end time is still in the future. Results are ordered
// by end time, then id, both descending, and include their items.
func GetAllActiveUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
	if userId <= 0 {
		return nil, errors.New("invalid userId")
	}
	var subs []UserSubscription
	query := DB.Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", common.GetTimestamp()).
		Order("end_time desc, id desc")
	if err := query.Find(&subs).Error; err != nil {
		return nil, err
	}
	return buildSubscriptionSummaries(subs)
}
// GetAllUserSubscriptions returns every subscription of a user, whatever its
// status. Results are ordered by end time, then id, both descending, and
// include their items.
func GetAllUserSubscriptions(userId int) ([]SubscriptionSummary, error) {
	if userId <= 0 {
		return nil, errors.New("invalid userId")
	}
	var subs []UserSubscription
	query := DB.Where("user_id = ?", userId).Order("end_time desc, id desc")
	if err := query.Find(&subs).Error; err != nil {
		return nil, err
	}
	return buildSubscriptionSummaries(subs)
}
// buildSubscriptionSummaries attaches each subscription's items, loaded with
// a single IN query, and preserves the input order. Each summary points at
// its own copy of the subscription. An empty input yields an empty, non-nil
// slice.
func buildSubscriptionSummaries(subs []UserSubscription) ([]SubscriptionSummary, error) {
	if len(subs) == 0 {
		return []SubscriptionSummary{}, nil
	}
	ids := make([]int, len(subs))
	for i := range subs {
		ids[i] = subs[i].Id
	}

	var items []UserSubscriptionItem
	if err := DB.Where("user_subscription_id IN ?", ids).Find(&items).Error; err != nil {
		return nil, err
	}
	itemsBySub := make(map[int][]UserSubscriptionItem, len(ids))
	for _, item := range items {
		itemsBySub[item.UserSubscriptionId] = append(itemsBySub[item.UserSubscriptionId], item)
	}

	summaries := make([]SubscriptionSummary, len(subs))
	for i := range subs {
		own := subs[i]
		summaries[i] = SubscriptionSummary{
			Subscription: &own,
			Items:        itemsBySub[own.Id],
		}
	}
	return summaries, nil
}
// AdminInvalidateUserSubscription cancels a user subscription by setting its
// status to "cancelled" and its end time to now. A missing id is not an
// error.
func AdminInvalidateUserSubscription(userSubscriptionId int) error {
	if userSubscriptionId <= 0 {
		return errors.New("invalid userSubscriptionId")
	}
	ts := common.GetTimestamp()
	fields := map[string]interface{}{
		"status":     "cancelled",
		"end_time":   ts,
		"updated_at": ts,
	}
	return DB.Model(&UserSubscription{}).
		Where("id = ?", userSubscriptionId).
		Updates(fields).Error
}
// AdminDeleteUserSubscription permanently removes a user subscription and its
// items in one transaction.
func AdminDeleteUserSubscription(userSubscriptionId int) error {
	if userSubscriptionId <= 0 {
		return errors.New("invalid userSubscriptionId")
	}
	return DB.Transaction(func(tx *gorm.DB) error {
		// Children first, so no item row outlives its parent.
		itemDel := tx.Where("user_subscription_id = ?", userSubscriptionId).Delete(&UserSubscriptionItem{})
		if itemDel.Error != nil {
			return itemDel.Error
		}
		return tx.Where("id = ?", userSubscriptionId).Delete(&UserSubscription{}).Error
	})
}
// SubscriptionPreConsumeResult describes the outcome of
// PreConsumeUserSubscription for one request.
type SubscriptionPreConsumeResult struct {
	UserSubscriptionId int
	ItemId             int
	QuotaType          int
	// PreConsumed is the amount charged for this request id.
	PreConsumed int64
	AmountTotal int64
	// AmountUsedBefore/AmountUsedAfter bracket the increment. When the call
	// replays an already-recorded request id, both equal the item's current
	// usage.
	AmountUsedBefore int64
	AmountUsedAfter  int64
}

// SubscriptionPreConsumeRecord stores idempotent pre-consume operations per request.
// RequestId is unique, so a request can pre-consume at most once; Status is
// "consumed" until RefundSubscriptionPreConsume flips it to "refunded".
// Rows older than a retention window (by UpdatedAt) are removed by
// CleanupSubscriptionPreConsumeRecords.
type SubscriptionPreConsumeRecord struct {
	Id                     int    `json:"id"`
	RequestId              string `json:"request_id" gorm:"type:varchar(64);uniqueIndex"`
	UserId                 int    `json:"user_id" gorm:"index"`
	UserSubscriptionItemId int    `json:"user_subscription_item_id" gorm:"index"`
	PreConsumed            int64  `json:"pre_consumed" gorm:"type:bigint;not null;default:0"`
	Status                 string `json:"status" gorm:"type:varchar(32);index"` // consumed/refunded
	CreatedAt              int64  `json:"created_at" gorm:"bigint"`
	UpdatedAt              int64  `json:"updated_at" gorm:"bigint;index"`
}
// BeforeCreate is a GORM hook that stamps CreatedAt and UpdatedAt with one
// shared timestamp.
func (r *SubscriptionPreConsumeRecord) BeforeCreate(tx *gorm.DB) error {
	ts := common.GetTimestamp()
	r.CreatedAt, r.UpdatedAt = ts, ts
	return nil
}

// BeforeUpdate is a GORM hook that refreshes UpdatedAt, which also drives
// record cleanup.
func (r *SubscriptionPreConsumeRecord) BeforeUpdate(tx *gorm.DB) error {
	ts := common.GetTimestamp()
	r.UpdatedAt = ts
	return nil
}
// maybeResetSubscriptionItemTx is maybeResetSubscriptionItemWithPlanTx for
// callers that only hold the item: it loads the parent subscription and its
// plan through tx before delegating. Items whose next reset is still in the
// future return immediately, without any lookups.
func maybeResetSubscriptionItemTx(tx *gorm.DB, item *UserSubscriptionItem, now int64) error {
	if tx == nil || item == nil {
		return errors.New("invalid reset args")
	}
	notDueYet := item.NextResetTime > 0 && item.NextResetTime > now
	if notDueYet {
		return nil
	}
	sub := &UserSubscription{}
	if err := tx.Where("id = ?", item.UserSubscriptionId).First(sub).Error; err != nil {
		return err
	}
	plan, err := getSubscriptionPlanByIdTx(tx, sub.PlanId)
	if err != nil {
		return err
	}
	return maybeResetSubscriptionItemWithPlanTx(tx, item, sub, plan, now)
}
// maybeResetSubscriptionItemWithPlanTx applies any quota resets that are due
// at now to item, both in memory and in the database.
//
// Starting from the item's last reset (or the subscription start), it walks
// the plan's reset schedule forward past every boundary <= now. If at least
// one boundary was crossed, usage is zeroed and the bookkeeping moves to the
// latest crossed boundary. If none was crossed but NextResetTime is missing,
// only the bookkeeping is filled in.
//
// Only the affected columns are written, never the whole row. A stale
// amount_used read earlier in the caller therefore cannot overwrite a
// concurrent update: the callers' "gorm:query_option" FOR UPDATE is not
// honoured by GORM v2, so the row is not actually locked.
func maybeResetSubscriptionItemWithPlanTx(tx *gorm.DB, item *UserSubscriptionItem, sub *UserSubscription, plan *SubscriptionPlan, now int64) error {
	if tx == nil || item == nil || sub == nil || plan == nil {
		return errors.New("invalid reset args")
	}
	if item.NextResetTime > 0 && item.NextResetTime > now {
		return nil
	}
	if NormalizeResetPeriod(plan.QuotaResetPeriod) == SubscriptionResetNever {
		return nil
	}
	baseUnix := item.LastResetTime
	if baseUnix <= 0 {
		baseUnix = sub.StartTime
	}
	base := time.Unix(baseUnix, 0)
	next := calcNextResetTime(base, plan, sub.EndTime)
	advanced := false
	for next > 0 && next <= now {
		advanced = true
		base = time.Unix(next, 0)
		next = calcNextResetTime(base, plan, sub.EndTime)
	}
	if !advanced {
		// Keep the next reset time in sync if it is missing.
		if item.NextResetTime != 0 || next <= 0 {
			return nil
		}
		item.NextResetTime = next
		item.LastResetTime = base.Unix()
		return tx.Model(item).Updates(map[string]interface{}{
			"last_reset_time": item.LastResetTime,
			"next_reset_time": item.NextResetTime,
		}).Error
	}
	item.AmountUsed = 0
	item.LastResetTime = base.Unix()
	item.NextResetTime = next
	return tx.Model(item).Updates(map[string]interface{}{
		"amount_used":     item.AmountUsed,
		"last_reset_time": item.LastResetTime,
		"next_reset_time": item.NextResetTime,
	}).Error
}
// fillSubscriptionPreConsumeReplay fills dst for a request id that was
// already pre-consumed: nothing new is charged, so usage before and after
// both equal the item's current usage.
func fillSubscriptionPreConsumeReplay(dst *SubscriptionPreConsumeResult, item *UserSubscriptionItem, preConsumed int64) {
	dst.UserSubscriptionId = item.UserSubscriptionId
	dst.ItemId = item.Id
	dst.QuotaType = item.QuotaType
	dst.PreConsumed = preConsumed
	dst.AmountTotal = item.AmountTotal
	dst.AmountUsedBefore = item.AmountUsed
	dst.AmountUsedAfter = item.AmountUsed
}

// PreConsumeUserSubscription finds a valid active subscription item and increments amount_used.
// quotaType=0 => consume quota units; quotaType=1 => consume request count (usually 1).
//
// Only the user's active subscription with the latest end time (ties broken
// by highest id) is considered; its newest item matching modelName and
// quotaType is charged after any due quota reset is applied.
//
// The call is idempotent per requestId. A repeated id returns the originally
// recorded amount without charging again. An id that was already refunded is
// rejected.
//
// The charge is a conditional atomic UPDATE (amount_used + amount must stay
// <= amount_total). Concurrent requests on the same item therefore cannot
// overspend or lose increments, even though the "gorm:query_option" FOR
// UPDATE read below is not honoured by GORM v2. Under concurrency,
// AmountUsedBefore/After are computed from this call's read and may lag
// other requests.
func PreConsumeUserSubscription(requestId string, userId int, modelName string, quotaType int, amount int64) (*SubscriptionPreConsumeResult, error) {
	if userId <= 0 {
		return nil, errors.New("invalid userId")
	}
	if strings.TrimSpace(requestId) == "" {
		return nil, errors.New("requestId is empty")
	}
	if modelName == "" {
		return nil, errors.New("modelName is empty")
	}
	if amount <= 0 {
		return nil, errors.New("amount must be > 0")
	}
	now := GetDBTimestamp()
	returnValue := &SubscriptionPreConsumeResult{}
	err := DB.Transaction(func(tx *gorm.DB) error {
		// Idempotency: replay an existing record for this request id.
		var existing SubscriptionPreConsumeRecord
		query := tx.Where("request_id = ?", requestId).Limit(1).Find(&existing)
		if query.Error != nil {
			return query.Error
		}
		if query.RowsAffected > 0 {
			if existing.Status == "refunded" {
				return errors.New("subscription pre-consume already refunded")
			}
			var item UserSubscriptionItem
			if err := tx.Where("id = ?", existing.UserSubscriptionItemId).First(&item).Error; err != nil {
				return err
			}
			fillSubscriptionPreConsumeReplay(returnValue, &item, existing.PreConsumed)
			return nil
		}

		var activeSub UserSubscription
		if err := tx.Where("user_id = ? AND status = ? AND end_time > ?", userId, "active", now).
			Order("end_time desc, id desc").
			First(&activeSub).Error; err != nil {
			return errors.New("no active subscription item for this model")
		}
		var candidate UserSubscriptionItem
		if err := tx.Where("user_subscription_id = ? AND model_name = ? AND quota_type = ?", activeSub.Id, modelName, quotaType).
			Order("id desc").
			First(&candidate).Error; err != nil {
			return errors.New("no active subscription item for this model")
		}
		var item UserSubscriptionItem
		if err := tx.Set("gorm:query_option", "FOR UPDATE").
			Where("id = ?", candidate.Id).
			First(&item).Error; err != nil {
			return errors.New("no active subscription item for this model")
		}
		var sub UserSubscription
		if err := tx.Where("id = ? AND user_id = ? AND status = ? AND end_time > ?", item.UserSubscriptionId, userId, "active", now).
			First(&sub).Error; err != nil {
			return errors.New("no active subscription item for this model")
		}
		plan, err := getSubscriptionPlanByIdTx(tx, sub.PlanId)
		if err != nil {
			return err
		}
		if err := maybeResetSubscriptionItemWithPlanTx(tx, &item, &sub, plan, now); err != nil {
			return err
		}
		usedBefore := item.AmountUsed
		remain := item.AmountTotal - usedBefore
		if remain < amount {
			return fmt.Errorf("subscription quota insufficient, remain=%d need=%d", remain, amount)
		}

		record := &SubscriptionPreConsumeRecord{
			RequestId:              requestId,
			UserId:                 userId,
			UserSubscriptionItemId: item.Id,
			PreConsumed:            amount,
			Status:                 "consumed",
		}
		// Insert under a savepoint (nested transaction). PostgreSQL aborts the
		// whole transaction on a unique violation, and without the savepoint
		// the duplicate lookup below would fail there.
		if err := tx.Transaction(func(inner *gorm.DB) error {
			return inner.Create(record).Error
		}); err != nil {
			var dup SubscriptionPreConsumeRecord
			if err2 := tx.Where("request_id = ?", requestId).First(&dup).Error; err2 == nil {
				if dup.Status == "refunded" {
					return errors.New("subscription pre-consume already refunded")
				}
				fillSubscriptionPreConsumeReplay(returnValue, &item, dup.PreConsumed)
				return nil
			}
			return err
		}

		// Atomic, bounded increment: succeeds only if the quota still fits.
		res := tx.Model(&UserSubscriptionItem{}).
			Where("id = ? AND amount_used + ? <= amount_total", item.Id, amount).
			Update("amount_used", gorm.Expr("amount_used + ?", amount))
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			// A concurrent request consumed the remaining quota after our read.
			return fmt.Errorf("subscription quota insufficient, remain=%d need=%d", remain, amount)
		}
		item.AmountUsed = usedBefore + amount

		returnValue.UserSubscriptionId = item.UserSubscriptionId
		returnValue.ItemId = item.Id
		returnValue.QuotaType = item.QuotaType
		returnValue.PreConsumed = amount
		returnValue.AmountTotal = item.AmountTotal
		returnValue.AmountUsedBefore = usedBefore
		returnValue.AmountUsedAfter = item.AmountUsed
		return nil
	})
	if err != nil {
		return nil, err
	}
	return returnValue, nil
}
// RefundSubscriptionPreConsume is idempotent and refunds pre-consumed subscription quota by requestId.
//
// The record is claimed with a conditional status update
// (status <> 'refunded'), so concurrent refunds of one request subtract the
// quota exactly once; the "gorm:query_option" FOR UPDATE setting is not
// honoured by GORM v2. The usage decrement runs in the same transaction and
// is clamped at 0, so the record and the item change together or not at all.
// A refund whose item no longer exists still marks the record refunded.
//
// Returns gorm.ErrRecordNotFound when no record exists for requestId.
func RefundSubscriptionPreConsume(requestId string) error {
	if strings.TrimSpace(requestId) == "" {
		return errors.New("requestId is empty")
	}
	return DB.Transaction(func(tx *gorm.DB) error {
		var record SubscriptionPreConsumeRecord
		if err := tx.Where("request_id = ?", requestId).First(&record).Error; err != nil {
			return err
		}
		if record.Status == "refunded" {
			return nil
		}
		claim := tx.Model(&SubscriptionPreConsumeRecord{}).
			Where("id = ? AND status <> ?", record.Id, "refunded").
			Updates(map[string]interface{}{
				"status":     "refunded",
				"updated_at": common.GetTimestamp(),
			})
		if claim.Error != nil {
			return claim.Error
		}
		if claim.RowsAffected == 0 {
			// A concurrent refund already claimed this record.
			return nil
		}
		if record.PreConsumed <= 0 {
			return nil
		}
		return tx.Model(&UserSubscriptionItem{}).
			Where("id = ?", record.UserSubscriptionItemId).
			Update("amount_used", gorm.Expr("CASE WHEN amount_used < ? THEN 0 ELSE amount_used - ? END", record.PreConsumed, record.PreConsumed)).
			Error
	})
}
// ResetDueSubscriptionItems resets up to limit items (default 200) whose
// next_reset_time has passed, oldest first. Parent subscriptions are loaded
// in one query and plans once per plan id. Each item is then reset in its own
// transaction after re-checking that it is still due.
//
// Returns the number of items processed. On error it returns the count so
// far along with the error.
func ResetDueSubscriptionItems(limit int) (int, error) {
	if limit <= 0 {
		limit = 200
	}
	now := GetDBTimestamp()

	var due []UserSubscriptionItem
	err := DB.Where("next_reset_time > 0 AND next_reset_time <= ?", now).
		Order("next_reset_time asc").
		Limit(limit).
		Find(&due).Error
	if err != nil {
		return 0, err
	}
	if len(due) == 0 {
		return 0, nil
	}

	// Distinct, positive parent subscription ids.
	seen := make(map[int]struct{}, len(due))
	subIDs := make([]int, 0, len(due))
	for _, it := range due {
		id := it.UserSubscriptionId
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		subIDs = append(subIDs, id)
	}

	subs := make(map[int]*UserSubscription, len(subIDs))
	if len(subIDs) > 0 {
		var rows []UserSubscription
		if err := DB.Where("id IN ?", subIDs).Find(&rows).Error; err != nil {
			return 0, err
		}
		for i := range rows {
			subs[rows[i].Id] = &rows[i]
		}
	}

	plans := make(map[int]*SubscriptionPlan, len(subs))
	for _, sub := range subs {
		if sub == nil || sub.PlanId <= 0 {
			continue
		}
		if _, loaded := plans[sub.PlanId]; loaded {
			continue
		}
		plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
		if err != nil {
			return 0, err
		}
		plans[sub.PlanId] = plan
	}

	resetCount := 0
	for _, it := range due {
		sub := subs[it.UserSubscriptionId]
		if sub == nil {
			continue
		}
		plan := plans[sub.PlanId]
		if plan == nil {
			continue
		}
		itemID := it.Id
		txErr := DB.Transaction(func(tx *gorm.DB) error {
			var item UserSubscriptionItem
			if err := tx.Set("gorm:query_option", "FOR UPDATE").
				Where("id = ? AND next_reset_time > 0 AND next_reset_time <= ?", itemID, now).
				First(&item).Error; err != nil {
				// No longer due (or gone): nothing to do for this item.
				return nil
			}
			if err := maybeResetSubscriptionItemWithPlanTx(tx, &item, sub, plan, now); err != nil {
				return err
			}
			resetCount++
			return nil
		})
		if txErr != nil {
			return resetCount, txErr
		}
	}
	return resetCount, nil
}
// CleanupSubscriptionPreConsumeRecords removes old idempotency records to keep table small.
func CleanupSubscriptionPreConsumeRecords(olderThanSeconds int64) (int64, error) {
if olderThanSeconds <= 0 {
olderThanSeconds = 7 * 24 * 3600
}
cutoff := GetDBTimestamp() - olderThanSeconds
res := DB.Where("updated_at < ?", cutoff).Delete(&SubscriptionPreConsumeRecord{})
return res.RowsAffected, res.Error
}
// SubscriptionPlanInfo is the plan id and title behind a user subscription.
// It is cached under "sub:<userSubscriptionId>". It has no JSON tags, so the
// Redis JSON codec stores it under the Go field names.
type SubscriptionPlanInfo struct {
	PlanId    int
	PlanTitle string
}
// GetSubscriptionPlanInfoByUserSubscriptionId resolves the plan id and title
// of a user subscription, consulting the plan-info cache first and
// populating it on a miss.
func GetSubscriptionPlanInfoByUserSubscriptionId(userSubscriptionId int) (*SubscriptionPlanInfo, error) {
	if userSubscriptionId <= 0 {
		return nil, errors.New("invalid userSubscriptionId")
	}
	cache := getSubscriptionPlanInfoCache()
	cacheKey := fmt.Sprintf("sub:%d", userSubscriptionId)
	if hit, ok, err := cache.Get(cacheKey); err == nil && ok {
		return &hit, nil
	}

	var sub UserSubscription
	if err := DB.Where("id = ?", userSubscriptionId).First(&sub).Error; err != nil {
		return nil, err
	}
	plan, err := getSubscriptionPlanByIdTx(nil, sub.PlanId)
	if err != nil {
		return nil, err
	}
	info := SubscriptionPlanInfo{PlanId: sub.PlanId, PlanTitle: plan.Title}
	// Best-effort cache fill.
	_ = cache.SetWithTTL(cacheKey, info, subscriptionPlanInfoCacheTTL())
	return &info, nil
}
// GetSubscriptionPlanInfoBySubscriptionItemId resolves plan info starting
// from a subscription item: it loads the item and then delegates to its
// parent subscription's lookup.
func GetSubscriptionPlanInfoBySubscriptionItemId(itemId int) (*SubscriptionPlanInfo, error) {
	if itemId <= 0 {
		return nil, errors.New("invalid itemId")
	}
	item := UserSubscriptionItem{}
	if err := DB.Where("id = ?", itemId).First(&item).Error; err != nil {
		return nil, err
	}
	parentID := item.UserSubscriptionId
	return GetSubscriptionPlanInfoByUserSubscriptionId(parentID)
}
// PostConsumeUserSubscriptionDelta adjusts an item's used amount by delta
// (positive consumes more, negative refunds). The result is clamped at 0.
// The change is rejected, with nothing written, if the new usage would
// exceed the item's total.
//
// The adjustment is one conditional atomic UPDATE. Concurrent deltas on the
// same item therefore cannot overwrite each other: the previous
// read-modify-write relied on the "gorm:query_option" FOR UPDATE setting,
// which GORM v2 ignores. When no row changes, the item is re-read to tell a
// missing item or an over-limit delta apart from a no-op.
func PostConsumeUserSubscriptionDelta(itemId int, delta int64) error {
	if itemId <= 0 {
		return errors.New("invalid itemId")
	}
	if delta == 0 {
		return nil
	}
	const clamped = "CASE WHEN amount_used + ? < 0 THEN 0 ELSE amount_used + ? END"
	res := DB.Model(&UserSubscriptionItem{}).
		Where("id = ? AND ("+clamped+") <= amount_total", itemId, delta, delta).
		Update("amount_used", gorm.Expr(clamped, delta, delta))
	if res.Error != nil {
		return res.Error
	}
	if res.RowsAffected > 0 {
		return nil
	}
	// No row changed: the item is missing, the delta exceeds the total, or
	// (MySQL reports changed rows only) the clamped value was already in place.
	var item UserSubscriptionItem
	if err := DB.Where("id = ?", itemId).First(&item).Error; err != nil {
		return err
	}
	newUsed := item.AmountUsed + delta
	if newUsed < 0 {
		newUsed = 0
	}
	if newUsed > item.AmountTotal {
		return fmt.Errorf("subscription used exceeds total, used=%d total=%d", newUsed, item.AmountTotal)
	}
	return nil
}