diff --git a/.env.example b/.env.example index f4b9d02ee..0a64758dd 100644 --- a/.env.example +++ b/.env.example @@ -85,3 +85,8 @@ LINUX_DO_USER_ENDPOINT=https://connect.linux.do/api/user # 节点类型 # 如果是主节点则为master # NODE_TYPE=master + +# 可信任重定向域名列表(逗号分隔,支持子域名匹配) +# 用于验证支付成功/取消回调URL的域名安全性 +# 示例: example.com,myapp.io 将允许 example.com, sub.example.com, myapp.io 等 +# TRUSTED_REDIRECT_DOMAINS=example.com,myapp.io diff --git a/common/init.go b/common/init.go index 9501ce3be..cf62f408e 100644 --- a/common/init.go +++ b/common/init.go @@ -159,4 +159,17 @@ func initConstantEnv() { } constant.TaskPricePatches = taskPricePatches } + + // Initialize trusted redirect domains for URL validation + trustedDomainsStr := GetEnvOrDefaultString("TRUSTED_REDIRECT_DOMAINS", "") + var trustedDomains []string + domains := strings.Split(trustedDomainsStr, ",") + for _, domain := range domains { + trimmedDomain := strings.TrimSpace(domain) + if trimmedDomain != "" { + // Normalize domain to lowercase + trustedDomains = append(trustedDomains, strings.ToLower(trimmedDomain)) + } + } + constant.TrustedRedirectDomains = trustedDomains } diff --git a/common/url_validator.go b/common/url_validator.go new file mode 100644 index 000000000..151f643f1 --- /dev/null +++ b/common/url_validator.go @@ -0,0 +1,39 @@ +package common + +import ( + "fmt" + "net/url" + "strings" + + "github.com/QuantumNous/new-api/constant" +) + +// ValidateRedirectURL validates that a redirect URL is safe to use. +// It checks that: +// - The URL is properly formatted +// - The scheme is either http or https +// - The domain is in the trusted domains list (exact match or subdomain) +// +// Returns nil if the URL is valid and trusted, otherwise returns an error +// describing why the validation failed. +func ValidateRedirectURL(rawURL string) error { + // Parse the URL + parsedURL, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL format: %s", err.Error()) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("invalid URL scheme: only http and https are allowed") + } + + domain := strings.ToLower(parsedURL.Hostname()) + + for _, trustedDomain := range constant.TrustedRedirectDomains { + if domain == trustedDomain || strings.HasSuffix(domain, "."+trustedDomain) { + return nil + } + } + + return fmt.Errorf("domain %s is not in the trusted domains list", domain) +} diff --git a/common/url_validator_test.go b/common/url_validator_test.go new file mode 100644 index 000000000..b87b6787e --- /dev/null +++ b/common/url_validator_test.go @@ -0,0 +1,134 @@ +package common + +import ( + "testing" + + "github.com/QuantumNous/new-api/constant" +) + +func TestValidateRedirectURL(t *testing.T) { + // Save original trusted domains and restore after test + originalDomains := constant.TrustedRedirectDomains + defer func() { + constant.TrustedRedirectDomains = originalDomains + }() + + tests := []struct { + name string + url string + trustedDomains []string + wantErr bool + errContains string + }{ + // Valid cases + { + name: "exact domain match with https", + url: "https://example.com/success", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + { + name: "exact domain match with http", + url: "http://example.com/callback", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + { + name: "subdomain match", + url: "https://sub.example.com/success", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + { + name: "case insensitive domain", + url: "https://EXAMPLE.COM/success", + trustedDomains: []string{"example.com"}, + wantErr: false, + }, + + // Invalid cases - untrusted domain + { + name: "untrusted domain", + url: "https://evil.com/phishing", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "not in the trusted domains list", + }, + { + name: "suffix attack - fakeexample.com", + url: "https://fakeexample.com/success", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "not in the trusted domains list", + }, + { + name: "empty trusted domains list", + url: "https://example.com/success", + trustedDomains: []string{}, + wantErr: true, + errContains: "not in the trusted domains list", + }, + + // Invalid cases - scheme + { + name: "javascript scheme", + url: "javascript:alert('xss')", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "invalid URL scheme", + }, + { + name: "data scheme", + url: "data:text/html,", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "invalid URL scheme", + }, + + // Edge cases + { + name: "empty URL", + url: "", + trustedDomains: []string{"example.com"}, + wantErr: true, + errContains: "invalid URL scheme", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up trusted domains for this test case + constant.TrustedRedirectDomains = tt.trustedDomains + + err := ValidateRedirectURL(tt.url) + + if tt.wantErr { + if err == nil { + t.Errorf("ValidateRedirectURL(%q) expected error containing %q, got nil", tt.url, tt.errContains) + return + } + if tt.errContains != "" && !contains(err.Error(), tt.errContains) { + t.Errorf("ValidateRedirectURL(%q) error = %q, want error containing %q", tt.url, err.Error(), tt.errContains) + } + } else { + if err != nil { + t.Errorf("ValidateRedirectURL(%q) unexpected error: %v", tt.url, err) + } + } + }) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && findSubstring(s, substr))) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/constant/env.go b/constant/env.go index c561c207d..873427a7c 100644 --- a/constant/env.go +++ b/constant/env.go @@ -20,3 +20,7 @@ var TaskQueryLimit int // temporary variable for sora patch, will be removed in future var TaskPricePatches []string + +// TrustedRedirectDomains is a list of trusted domains for redirect URL validation. +// Domains support subdomain matching (e.g., "example.com" matches "sub.example.com"). +var TrustedRedirectDomains []string diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index 337ff8e73..2f5d7e722 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -28,9 +28,18 @@ const ( var stripeAdaptor = &StripeAdaptor{} +// StripePayRequest represents a payment request for Stripe checkout. type StripePayRequest struct { - Amount int64 `json:"amount"` + // Amount is the quantity of units to purchase. + Amount int64 `json:"amount"` + // PaymentMethod specifies the payment method (e.g., "stripe"). PaymentMethod string `json:"payment_method"` + // SuccessURL is the optional custom URL to redirect after successful payment. + // If empty, defaults to the server's console log page. + SuccessURL string `json:"success_url,omitempty"` + // CancelURL is the optional custom URL to redirect when payment is canceled. + // If empty, defaults to the server's console topup page. + CancelURL string `json:"cancel_url,omitempty"` } type StripeAdaptor struct { @@ -69,6 +78,16 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { return } + if req.SuccessURL != "" && common.ValidateRedirectURL(req.SuccessURL) != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": "支付成功重定向URL不在可信任域名列表中", "data": ""}) + return + } + + if req.CancelURL != "" && common.ValidateRedirectURL(req.CancelURL) != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": "支付取消重定向URL不在可信任域名列表中", "data": ""}) + return + } + id := c.GetInt("id") user, _ := model.GetUserById(id, false) chargedMoney := GetChargedAmount(float64(req.Amount), *user) @@ -76,7 +95,7 @@ func (*StripeAdaptor) RequestPay(c *gin.Context, req *StripePayRequest) { reference := fmt.Sprintf("new-api-ref-%d-%d-%s", user.Id, time.Now().UnixMilli(), randstr.String(4)) referenceId := "ref_" + common.Sha1([]byte(reference)) - payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount) + payLink, err := genStripeLink(referenceId, user.StripeCustomer, user.Email, req.Amount, req.SuccessURL, req.CancelURL) if err != nil { log.Println("获取Stripe Checkout支付链接失败", err) c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"}) @@ -210,17 +229,37 @@ func sessionExpired(event stripe.Event) { log.Println("充值订单已过期", referenceId) } -func genStripeLink(referenceId string, customerId string, email string, amount int64) (string, error) { +// genStripeLink generates a Stripe Checkout session URL for payment. +// It creates a new checkout session with the specified parameters and returns the payment URL. +// +// Parameters: +// - referenceId: unique reference identifier for the transaction +// - customerId: existing Stripe customer ID (empty string if new customer) +// - email: customer email address for new customer creation +// - amount: quantity of units to purchase +// - successURL: custom URL to redirect after successful payment (empty for default) +// - cancelURL: custom URL to redirect when payment is canceled (empty for default) +// +// Returns the checkout session URL or an error if the session creation fails. +func genStripeLink(referenceId string, customerId string, email string, amount int64, successURL string, cancelURL string) (string, error) { if !strings.HasPrefix(setting.StripeApiSecret, "sk_") && !strings.HasPrefix(setting.StripeApiSecret, "rk_") { return "", fmt.Errorf("无效的Stripe API密钥") } stripe.Key = setting.StripeApiSecret + // Use custom URLs if provided, otherwise use defaults + if successURL == "" { + successURL = system_setting.ServerAddress + "/console/log" + } + if cancelURL == "" { + cancelURL = system_setting.ServerAddress + "/console/topup" + } + params := &stripe.CheckoutSessionParams{ ClientReferenceID: stripe.String(referenceId), - SuccessURL: stripe.String(system_setting.ServerAddress + "/console/log"), - CancelURL: stripe.String(system_setting.ServerAddress + "/console/topup"), + SuccessURL: stripe.String(successURL), + CancelURL: stripe.String(cancelURL), LineItems: []*stripe.CheckoutSessionLineItemParams{ { Price: stripe.String(setting.StripePriceId),