mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-18 16:37:26 +00:00
Add plan-level quota reset periods and display/reset cadence in admin/UI. Enforce natural reset alignment with background reset task and cleanup job. Make subscription pre-consume/refund idempotent with request-scoped records and retries. Use database time for consistent resets across multi-instance deployments. Harden payment callbacks with locking and idempotent order completion. Record subscription purchases in topup history and billing logs. Optimize subscription queries and add critical composite indexes.
316 lines
8.9 KiB
Go
316 lines
8.9 KiB
Go
package controller
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/QuantumNous/new-api/model"
|
|
"github.com/QuantumNous/new-api/setting"
|
|
"github.com/QuantumNous/new-api/setting/operation_setting"
|
|
"github.com/QuantumNous/new-api/setting/system_setting"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/stripe/stripe-go/v81"
|
|
"github.com/stripe/stripe-go/v81/checkout/session"
|
|
"github.com/stripe/stripe-go/v81/webhook"
|
|
"github.com/thanhpk/randstr"
|
|
)
|
|
|
|
// PaymentMethodStripe is the payment method identifier for Stripe. It is the
// only value RequestPay accepts in StripePayRequest.PaymentMethod and is the
// value stored on top-up orders created through this file.
const PaymentMethodStripe = "stripe"
|
|
|
|
// stripeAdaptor is the shared adaptor instance that the package-level Stripe
// handlers (RequestStripeAmount, RequestStripePay) delegate to.
var stripeAdaptor = new(StripeAdaptor)
|
|
|
|
// StripePayRequest is the JSON body accepted by the Stripe top-up endpoints.
type StripePayRequest struct {
	// Amount is the requested top-up quantity; it is checked against the
	// configured minimum before any price is computed.
	Amount int64 `json:"amount"`
	// PaymentMethod names the payment channel; RequestPay only accepts
	// PaymentMethodStripe.
	PaymentMethod string `json:"payment_method"`
}
|
|
|
|
// StripeAdaptor groups the Stripe top-up request handlers. It carries no state.
type StripeAdaptor struct{}
|
|
|
|
func (*StripeAdaptor) RequestAmount(c *gin.Context, req *StripePayRequest) {
|
|
if req.Amount < getStripeMinTopup() {
|
|
c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup())})
|
|
return
|
|
}
|
|
id := c.GetInt("id")
|
|
group, err := model.GetUserGroup(id, true)
|
|
if err != nil {
|
|
c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"})
|
|
return
|
|
}
|
|
payMoney := getStripePayMoney(float64(req.Amount), group)
|
|
if payMoney <= 0.01 {
|
|
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
|
return
|
|
}
|
|
c.JSON(200, gin.H{"message": "success", "data": strconv.FormatFloat(payMoney, 'f', 2, 64)})
|
|
}
|
|
|
|
func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) {
|
|
if req.PaymentMethod != PaymentMethodStripe {
|
|
c.JSON(200, gin.H{"message": "error", "data": "不支持的支付渠道"})
|
|
return
|
|
}
|
|
if req.Amount < getStripeMinTopup() {
|
|
c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", getStripeMinTopup()), "data": 10})
|
|
return
|
|
}
|
|
if req.Amount > 10000 {
|
|
c.JSON(200, gin.H{"message": "充值数量不能大于 10000", "data": 10})
|
|
return
|
|
}
|
|
|
|
id := c.GetInt("id")
|
|
user, _ := model.GetUserById(id, false)
|
|
chargedMoney := GetChargedAmount(float64(req.Amount), *user)
|
|
|
|
reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4))
|
|
referenceId := "ref_" + common.Sha1([]byte(reference))
|
|
|
|
payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount)
|
|
if err != nil {
|
|
log.Println("获取Stripe Checkout支付链接失败", err)
|
|
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
|
return
|
|
}
|
|
|
|
topUp := &model.TopUp{
|
|
UserId: id,
|
|
Amount: req.Amount,
|
|
Money: chargedMoney,
|
|
TradeNo: referenceId,
|
|
PaymentMethod: PaymentMethodStripe,
|
|
CreateTime: time.Now().Unix(),
|
|
Status: common.TopUpStatusPending,
|
|
}
|
|
err = topUp.Insert()
|
|
if err != nil {
|
|
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
|
return
|
|
}
|
|
c.JSON(200, gin.H{
|
|
"message": "success",
|
|
"data": gin.H{
|
|
"pay_link": payLink,
|
|
},
|
|
})
|
|
}
|
|
|
|
func RequestStripeAmount(c *gin.Context) {
|
|
var req StripePayRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
|
return
|
|
}
|
|
stripeAdaptor.RequestAmount(c, &req)
|
|
}
|
|
|
|
func RequestStripePay(c *gin.Context) {
|
|
var req StripePayRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.JSON(200, gin.H{"message": "error", "data": "参数错误"})
|
|
return
|
|
}
|
|
stripeAdaptor.RequestPay(c, &req)
|
|
}
|
|
|
|
func StripeWebhook(c *gin.Context) {
|
|
payload, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
log.Printf("解析Stripe Webhook参数失败: %v\n", err)
|
|
c.AbortWithStatus(http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
signature := c.GetHeader("Stripe-Signature")
|
|
endpointSecret := setting.StripeWebhookSecret
|
|
event, err := webhook.ConstructEventWithOptions(payload, signature, endpointSecret, webhook.ConstructEventOptions{
|
|
IgnoreAPIVersionMismatch: true,
|
|
})
|
|
|
|
if err != nil {
|
|
log.Printf("Stripe Webhook验签失败: %v\n", err)
|
|
c.AbortWithStatus(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
switch event.Type {
|
|
case stripe.EventTypeCheckoutSessionCompleted:
|
|
sessionCompleted(event)
|
|
case stripe.EventTypeCheckoutSessionExpired:
|
|
sessionExpired(event)
|
|
default:
|
|
log.Printf("不支持的Stripe Webhook事件类型: %s\n", event.Type)
|
|
}
|
|
|
|
c.Status(http.StatusOK)
|
|
}
|
|
|
|
func sessionCompleted(event stripe.Event) {
|
|
customerId := event.GetObjectValue("customer")
|
|
referenceId := event.GetObjectValue("client_reference_id")
|
|
status := event.GetObjectValue("status")
|
|
if "complete" != status {
|
|
log.Println("错误的Stripe Checkout完成状态:", status, ",", referenceId)
|
|
return
|
|
}
|
|
|
|
// Try complete subscription order first
|
|
LockOrder(referenceId)
|
|
defer UnlockOrder(referenceId)
|
|
payload := map[string]any{
|
|
"customer": customerId,
|
|
"amount_total": event.GetObjectValue("amount_total"),
|
|
"currency": strings.ToUpper(event.GetObjectValue("currency")),
|
|
"event_type": string(event.Type),
|
|
}
|
|
if err := model.CompleteSubscriptionOrder(referenceId, jsonString(payload)); err == nil {
|
|
return
|
|
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
|
log.Println("complete subscription order failed:", err.Error(), referenceId)
|
|
return
|
|
}
|
|
|
|
err := model.Recharge(referenceId, customerId)
|
|
if err != nil {
|
|
log.Println(err.Error(), referenceId)
|
|
return
|
|
}
|
|
|
|
total, _ := strconv.ParseFloat(event.GetObjectValue("amount_total"), 64)
|
|
currency := strings.ToUpper(event.GetObjectValue("currency"))
|
|
log.Printf("收到款项:%s, %.2f(%s)", referenceId, total/100, currency)
|
|
}
|
|
|
|
func sessionExpired(event stripe.Event) {
|
|
referenceId := event.GetObjectValue("client_reference_id")
|
|
status := event.GetObjectValue("status")
|
|
if "expired" != status {
|
|
log.Println("错误的Stripe Checkout过期状态:", status, ",", referenceId)
|
|
return
|
|
}
|
|
|
|
if len(referenceId) == 0 {
|
|
log.Println("未提供支付单号")
|
|
return
|
|
}
|
|
|
|
// Subscription order expiration
|
|
LockOrder(referenceId)
|
|
defer UnlockOrder(referenceId)
|
|
if err := model.ExpireSubscriptionOrder(referenceId); err == nil {
|
|
return
|
|
} else if err != nil && !errors.Is(err, model.ErrSubscriptionOrderNotFound) {
|
|
log.Println("过期订阅订单失败", referenceId, ", err:", err.Error())
|
|
return
|
|
}
|
|
|
|
topUp := model.GetTopUpByTradeNo(referenceId)
|
|
if topUp == nil {
|
|
log.Println("充值订单不存在", referenceId)
|
|
return
|
|
}
|
|
|
|
if topUp.Status != common.TopUpStatusPending {
|
|
log.Println("充值订单状态错误", referenceId)
|
|
}
|
|
|
|
topUp.Status = common.TopUpStatusExpired
|
|
err := topUp.Update()
|
|
if err != nil {
|
|
log.Println("过期充值订单失败", referenceId, ", err:", err.Error())
|
|
return
|
|
}
|
|
|
|
log.Println("充值订单已过期", referenceId)
|
|
}
|
|
|
|
func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) {
|
|
if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") {
|
|
return "", fmt.Errorf("无效的Stripe API密钥")
|
|
}
|
|
|
|
stripe.Key = setting.StripeApiSecret
|
|
|
|
params := &stripe.CheckoutSessionParams{
|
|
ClientReferenceID: stripe.String(referenceId),
|
|
SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"),
|
|
CancelURL: stripe.String(system_setting.ServerAddress + "/console/topup"),
|
|
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
|
{
|
|
Price: stripe.String(setting.StripePriceId),
|
|
Quantity: stripe.Int64(amount),
|
|
},
|
|
},
|
|
Mode: stripe.String(string(stripe.CheckoutSessionModePayment)),
|
|
AllowPromotionCodes: stripe.Bool(setting.StripePromotionCodesEnabled),
|
|
}
|
|
|
|
if "" == customerId {
|
|
if "" != email {
|
|
params.CustomerEmail = stripe.String(email)
|
|
}
|
|
|
|
params.CustomerCreation = stripe.String(string(stripe.CheckoutSessionCustomerCreationAlways))
|
|
} else {
|
|
params.Customer = stripe.String(customerId)
|
|
}
|
|
|
|
result, err := session.New(params)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return result.URL, nil
|
|
}
|
|
|
|
func GetChargedAmount(count float64, user model.User) float64 {
|
|
topUpGroupRatio := common.GetTopupGroupRatio(user.Group)
|
|
if topUpGroupRatio == 0 {
|
|
topUpGroupRatio = 1
|
|
}
|
|
|
|
return count * topUpGroupRatio
|
|
}
|
|
|
|
func getStripePayMoney(amount float64, group string) float64 {
|
|
originalAmount := amount
|
|
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
|
|
amount = amount / common.QuotaPerUnit
|
|
}
|
|
// Using float64 for monetary calculations is acceptable here due to the small amounts involved
|
|
topupGroupRatio := common.GetTopupGroupRatio(group)
|
|
if topupGroupRatio == 0 {
|
|
topupGroupRatio = 1
|
|
}
|
|
// apply optional preset discount by the original request amount (if configured), default 1.0
|
|
discount := 1.0
|
|
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
|
if ds > 0 {
|
|
discount = ds
|
|
}
|
|
}
|
|
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
|
|
return payMoney
|
|
}
|
|
|
|
func getStripeMinTopup() int64 {
|
|
minTopup := setting.StripeMinTopUp
|
|
if operation_setting.GetQuotaDisplayType() == operation_setting.QuotaDisplayTypeTokens {
|
|
minTopup = minTopup * int(common.QuotaPerUnit)
|
|
}
|
|
return int64(minTopup)
|
|
}
|