mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-05-01 02:41:46 +00:00
feat: ssrf支持域名和ip黑白名单过滤模式
This commit is contained in:
@@ -11,16 +11,20 @@ import (
|
|||||||
// SSRFProtection SSRF防护配置
|
// SSRFProtection SSRF防护配置
|
||||||
type SSRFProtection struct {
|
type SSRFProtection struct {
|
||||||
AllowPrivateIp bool
|
AllowPrivateIp bool
|
||||||
WhitelistDomains []string // domain format, e.g. example.com, *.example.com
|
DomainFilterMode bool // true: 白名单, false: 黑名单
|
||||||
WhitelistIps []string // CIDR format
|
DomainList []string // domain format, e.g. example.com, *.example.com
|
||||||
|
IpFilterMode bool // true: 白名单, false: 黑名单
|
||||||
|
IpList []string // CIDR or single IP
|
||||||
AllowedPorts []int // 允许的端口范围
|
AllowedPorts []int // 允许的端口范围
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultSSRFProtection 默认SSRF防护配置
|
// DefaultSSRFProtection 默认SSRF防护配置
|
||||||
var DefaultSSRFProtection = &SSRFProtection{
|
var DefaultSSRFProtection = &SSRFProtection{
|
||||||
AllowPrivateIp: false,
|
AllowPrivateIp: false,
|
||||||
WhitelistDomains: []string{},
|
DomainFilterMode: true,
|
||||||
WhitelistIps: []string{},
|
DomainList: []string{},
|
||||||
|
IpFilterMode: true,
|
||||||
|
IpList: []string{},
|
||||||
AllowedPorts: []int{},
|
AllowedPorts: []int{},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,44 +142,25 @@ func (p *SSRFProtection) isAllowedPort(port int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isAllowedPortFromRanges 从端口范围字符串检查端口是否被允许
|
|
||||||
func isAllowedPortFromRanges(port int, portRanges []string) bool {
|
|
||||||
if len(portRanges) == 0 {
|
|
||||||
return true // 如果没有配置端口限制,则允许所有端口
|
|
||||||
}
|
|
||||||
|
|
||||||
allowedPorts, err := parsePortRanges(portRanges)
|
|
||||||
if err != nil {
|
|
||||||
// 如果解析失败,为安全起见拒绝访问
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, allowedPort := range allowedPorts {
|
|
||||||
if port == allowedPort {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// isDomainWhitelisted 检查域名是否在白名单中
|
// isDomainWhitelisted 检查域名是否在白名单中
|
||||||
func (p *SSRFProtection) isDomainWhitelisted(domain string) bool {
|
func isDomainListed(domain string, list []string) bool {
|
||||||
if len(p.WhitelistDomains) == 0 {
|
if len(list) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
domain = strings.ToLower(domain)
|
domain = strings.ToLower(domain)
|
||||||
for _, whitelistDomain := range p.WhitelistDomains {
|
for _, item := range list {
|
||||||
whitelistDomain = strings.ToLower(whitelistDomain)
|
item = strings.ToLower(strings.TrimSpace(item))
|
||||||
|
if item == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 精确匹配
|
// 精确匹配
|
||||||
if domain == whitelistDomain {
|
if domain == item {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 通配符匹配 (*.example.com)
|
// 通配符匹配 (*.example.com)
|
||||||
if strings.HasPrefix(whitelistDomain, "*.") {
|
if strings.HasPrefix(item, "*.") {
|
||||||
suffix := strings.TrimPrefix(whitelistDomain, "*.")
|
suffix := strings.TrimPrefix(item, "*.")
|
||||||
if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
|
if strings.HasSuffix(domain, "."+suffix) || domain == suffix {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -184,13 +169,23 @@ func (p *SSRFProtection) isDomainWhitelisted(domain string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *SSRFProtection) isDomainAllowed(domain string) bool {
|
||||||
|
listed := isDomainListed(domain, p.DomainList)
|
||||||
|
if p.DomainFilterMode { // 白名单
|
||||||
|
return listed
|
||||||
|
}
|
||||||
|
// 黑名单
|
||||||
|
return !listed
|
||||||
|
}
|
||||||
|
|
||||||
// isIPWhitelisted 检查IP是否在白名单中
|
// isIPWhitelisted 检查IP是否在白名单中
|
||||||
func (p *SSRFProtection) isIPWhitelisted(ip net.IP) bool {
|
|
||||||
if len(p.WhitelistIps) == 0 {
|
func isIPListed(ip net.IP, list []string) bool {
|
||||||
|
if len(list) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, whitelistCIDR := range p.WhitelistIps {
|
for _, whitelistCIDR := range list {
|
||||||
_, network, err := net.ParseCIDR(whitelistCIDR)
|
_, network, err := net.ParseCIDR(whitelistCIDR)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 尝试作为单个IP处理
|
// 尝试作为单个IP处理
|
||||||
@@ -211,22 +206,17 @@ func (p *SSRFProtection) isIPWhitelisted(ip net.IP) bool {
|
|||||||
|
|
||||||
// IsIPAccessAllowed 检查IP是否允许访问
|
// IsIPAccessAllowed 检查IP是否允许访问
|
||||||
func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
|
func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool {
|
||||||
// 如果IP在白名单中,直接允许访问(绕过私有IP检查)
|
// 私有IP限制
|
||||||
if p.isIPWhitelisted(ip) {
|
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||||
return true
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果IP白名单为空,允许所有IP(但仍需通过私有IP检查)
|
listed := isIPListed(ip, p.IpList)
|
||||||
if len(p.WhitelistIps) == 0 {
|
if p.IpFilterMode { // 白名单
|
||||||
// 检查私有IP限制
|
return listed
|
||||||
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
// 黑名单
|
||||||
// 如果IP白名单不为空且IP不在白名单中,拒绝访问
|
return !listed
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateURL 验证URL是否安全
|
// ValidateURL 验证URL是否安全
|
||||||
@@ -264,28 +254,44 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error {
|
|||||||
return fmt.Errorf("port %d is not allowed", port)
|
return fmt.Errorf("port %d is not allowed", port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查域名白名单
|
// 如果 host 是 IP,则跳过域名检查
|
||||||
if p.isDomainWhitelisted(host) {
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
return nil // 白名单域名直接通过
|
if !p.IsIPAccessAllowed(ip) {
|
||||||
|
if isPrivateIP(ip) {
|
||||||
|
return fmt.Errorf("private IP address not allowed: %s", ip.String())
|
||||||
|
}
|
||||||
|
if p.IpFilterMode {
|
||||||
|
return fmt.Errorf("ip not in whitelist: %s", ip.String())
|
||||||
|
}
|
||||||
|
return fmt.Errorf("ip in blacklist: %s", ip.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DNS解析获取IP地址
|
// 先进行域名过滤
|
||||||
|
if !p.isDomainAllowed(host) {
|
||||||
|
if p.DomainFilterMode {
|
||||||
|
return fmt.Errorf("domain not in whitelist: %s", host)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("domain in blacklist: %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析域名对应IP并检查
|
||||||
ips, err := net.LookupIP(host)
|
ips, err := net.LookupIP(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
|
return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查所有解析的IP地址
|
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
if !p.IsIPAccessAllowed(ip) {
|
if !p.IsIPAccessAllowed(ip) {
|
||||||
if isPrivateIP(ip) {
|
if isPrivateIP(ip) && !p.AllowPrivateIp {
|
||||||
return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
|
return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
|
||||||
} else {
|
|
||||||
return fmt.Errorf("IP address not in whitelist: %s resolves to %s", host, ip.String())
|
|
||||||
}
|
}
|
||||||
|
if p.IpFilterMode {
|
||||||
|
return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String())
|
||||||
|
}
|
||||||
|
return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,7 +301,7 @@ func ValidateURLWithDefaults(urlStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
|
// ValidateURLWithFetchSetting 使用FetchSetting配置验证URL
|
||||||
func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, whitelistDomains, whitelistIps, allowedPorts []string) error {
|
func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string) error {
|
||||||
// 如果SSRF防护被禁用,直接返回成功
|
// 如果SSRF防护被禁用,直接返回成功
|
||||||
if !enableSSRFProtection {
|
if !enableSSRFProtection {
|
||||||
return nil
|
return nil
|
||||||
@@ -309,76 +315,11 @@ func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPriva
|
|||||||
|
|
||||||
protection := &SSRFProtection{
|
protection := &SSRFProtection{
|
||||||
AllowPrivateIp: allowPrivateIp,
|
AllowPrivateIp: allowPrivateIp,
|
||||||
WhitelistDomains: whitelistDomains,
|
DomainFilterMode: domainFilterMode,
|
||||||
WhitelistIps: whitelistIps,
|
DomainList: domainList,
|
||||||
|
IpFilterMode: ipFilterMode,
|
||||||
|
IpList: ipList,
|
||||||
AllowedPorts: allowedPortInts,
|
AllowedPorts: allowedPortInts,
|
||||||
}
|
}
|
||||||
return protection.ValidateURL(urlStr)
|
return protection.ValidateURL(urlStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateURLWithPortRanges 直接使用端口范围字符串验证URL(更高效的版本)
|
|
||||||
func ValidateURLWithPortRanges(urlStr string, allowPrivateIp bool, whitelistDomains, whitelistIps, allowedPorts []string) error {
|
|
||||||
// 解析URL
|
|
||||||
u, err := url.Parse(urlStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid URL format: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只允许HTTP/HTTPS协议
|
|
||||||
if u.Scheme != "http" && u.Scheme != "https" {
|
|
||||||
return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析主机和端口
|
|
||||||
host, portStr, err := net.SplitHostPort(u.Host)
|
|
||||||
if err != nil {
|
|
||||||
// 没有端口,使用默认端口
|
|
||||||
host = u.Host
|
|
||||||
if u.Scheme == "https" {
|
|
||||||
portStr = "443"
|
|
||||||
} else {
|
|
||||||
portStr = "80"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证端口
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid port: %s", portStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isAllowedPortFromRanges(port, allowedPorts) {
|
|
||||||
return fmt.Errorf("port %d is not allowed", port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建临时的SSRFProtection来复用域名和IP检查逻辑
|
|
||||||
protection := &SSRFProtection{
|
|
||||||
AllowPrivateIp: allowPrivateIp,
|
|
||||||
WhitelistDomains: whitelistDomains,
|
|
||||||
WhitelistIps: whitelistIps,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查域名白名单
|
|
||||||
if protection.isDomainWhitelisted(host) {
|
|
||||||
return nil // 白名单域名直接通过
|
|
||||||
}
|
|
||||||
|
|
||||||
// DNS解析获取IP地址
|
|
||||||
ips, err := net.LookupIP(host)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("DNS resolution failed for %s: %v", host, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查所有解析的IP地址
|
|
||||||
for _, ip := range ips {
|
|
||||||
if !protection.IsIPAccessAllowed(ip) {
|
|
||||||
if isPrivateIP(ip) {
|
|
||||||
return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String())
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("IP address not in whitelist: %s resolves to %s", host, ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) {
|
|||||||
|
|
||||||
// SSRF防护:验证请求URL
|
// SSRF防护:验证请求URL
|
||||||
fetchSetting := system_setting.GetFetchSetting()
|
fetchSetting := system_setting.GetFetchSetting()
|
||||||
if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
||||||
return nil, fmt.Errorf("request reject: %v", err)
|
return nil, fmt.Errorf("request reject: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response,
|
|||||||
} else {
|
} else {
|
||||||
// SSRF防护:验证请求URL(非Worker模式)
|
// SSRF防护:验证请求URL(非Worker模式)
|
||||||
fetchSetting := system_setting.GetFetchSetting()
|
fetchSetting := system_setting.GetFetchSetting()
|
||||||
if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
||||||
return nil, fmt.Errorf("request reject: %v", err)
|
return nil, fmt.Errorf("request reject: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ func sendBarkNotify(barkURL string, data dto.Notify) error {
|
|||||||
} else {
|
} else {
|
||||||
// SSRF防护:验证Bark URL(非Worker模式)
|
// SSRF防护:验证Bark URL(非Worker模式)
|
||||||
fetchSetting := system_setting.GetFetchSetting()
|
fetchSetting := system_setting.GetFetchSetting()
|
||||||
if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
||||||
return fmt.Errorf("request reject: %v", err)
|
return fmt.Errorf("request reject: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error
|
|||||||
} else {
|
} else {
|
||||||
// SSRF防护:验证Webhook URL(非Worker模式)
|
// SSRF防护:验证Webhook URL(非Worker模式)
|
||||||
fetchSetting := system_setting.GetFetchSetting()
|
fetchSetting := system_setting.GetFetchSetting()
|
||||||
if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil {
|
||||||
return fmt.Errorf("request reject: %v", err)
|
return fmt.Errorf("request reject: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user