diff --git a/common/ssrf_protection.go b/common/ssrf_protection.go index b0988d907..52b839525 100644 --- a/common/ssrf_protection.go +++ b/common/ssrf_protection.go @@ -11,16 +11,20 @@ import ( // SSRFProtection SSRF防护配置 type SSRFProtection struct { AllowPrivateIp bool - WhitelistDomains []string // domain format, e.g. example.com, *.example.com - WhitelistIps []string // CIDR format + DomainFilterMode bool // true: 白名单, false: 黑名单 + DomainList []string // domain format, e.g. example.com, *.example.com + IpFilterMode bool // true: 白名单, false: 黑名单 + IpList []string // CIDR or single IP AllowedPorts []int // 允许的端口范围 } // DefaultSSRFProtection 默认SSRF防护配置 var DefaultSSRFProtection = &SSRFProtection{ AllowPrivateIp: false, - WhitelistDomains: []string{}, - WhitelistIps: []string{}, + DomainFilterMode: true, + DomainList: []string{}, + IpFilterMode: true, + IpList: []string{}, AllowedPorts: []int{}, } @@ -138,44 +142,25 @@ func (p *SSRFProtection) isAllowedPort(port int) bool { return false } -// isAllowedPortFromRanges 从端口范围字符串检查端口是否被允许 -func isAllowedPortFromRanges(port int, portRanges []string) bool { - if len(portRanges) == 0 { - return true // 如果没有配置端口限制,则允许所有端口 - } - - allowedPorts, err := parsePortRanges(portRanges) - if err != nil { - // 如果解析失败,为安全起见拒绝访问 - return false - } - - for _, allowedPort := range allowedPorts { - if port == allowedPort { - return true - } - } - return false -} - // isDomainWhitelisted 检查域名是否在白名单中 -func (p *SSRFProtection) isDomainWhitelisted(domain string) bool { - if len(p.WhitelistDomains) == 0 { +func isDomainListed(domain string, list []string) bool { + if len(list) == 0 { return false } domain = strings.ToLower(domain) - for _, whitelistDomain := range p.WhitelistDomains { - whitelistDomain = strings.ToLower(whitelistDomain) - + for _, item := range list { + item = strings.ToLower(strings.TrimSpace(item)) + if item == "" { + continue + } // 精确匹配 - if domain == whitelistDomain { + if domain == item { return true } - // 通配符匹配 (*.example.com) - if strings.HasPrefix(whitelistDomain, "*.") { - suffix := strings.TrimPrefix(whitelistDomain, "*.") + if strings.HasPrefix(item, "*.") { + suffix := strings.TrimPrefix(item, "*.") if strings.HasSuffix(domain, "."+suffix) || domain == suffix { return true } @@ -184,13 +169,23 @@ func (p *SSRFProtection) isDomainWhitelisted(domain string) bool { return false } +func (p *SSRFProtection) isDomainAllowed(domain string) bool { + listed := isDomainListed(domain, p.DomainList) + if p.DomainFilterMode { // 白名单 + return listed + } + // 黑名单 + return !listed +} + // isIPWhitelisted 检查IP是否在白名单中 -func (p *SSRFProtection) isIPWhitelisted(ip net.IP) bool { - if len(p.WhitelistIps) == 0 { + +func isIPListed(ip net.IP, list []string) bool { + if len(list) == 0 { return false } - for _, whitelistCIDR := range p.WhitelistIps { + for _, whitelistCIDR := range list { _, network, err := net.ParseCIDR(whitelistCIDR) if err != nil { // 尝试作为单个IP处理 @@ -211,22 +206,17 @@ func (p *SSRFProtection) isIPWhitelisted(ip net.IP) bool { // IsIPAccessAllowed 检查IP是否允许访问 func (p *SSRFProtection) IsIPAccessAllowed(ip net.IP) bool { - // 如果IP在白名单中,直接允许访问(绕过私有IP检查) - if p.isIPWhitelisted(ip) { - return true + // 私有IP限制 + if isPrivateIP(ip) && !p.AllowPrivateIp { + return false } - // 如果IP白名单为空,允许所有IP(但仍需通过私有IP检查) - if len(p.WhitelistIps) == 0 { - // 检查私有IP限制 - if isPrivateIP(ip) && !p.AllowPrivateIp { - return false - } - return true + listed := isIPListed(ip, p.IpList) + if p.IpFilterMode { // 白名单 + return listed } - - // 如果IP白名单不为空且IP不在白名单中,拒绝访问 - return false + // 黑名单 + return !listed } // ValidateURL 验证URL是否安全 @@ -264,28 +254,44 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error { return fmt.Errorf("port %d is not allowed", port) } - // 检查域名白名单 - if p.isDomainWhitelisted(host) { - return nil // 白名单域名直接通过 + // 如果 host 是 IP,则跳过域名检查 + if ip := net.ParseIP(host); ip != nil { + if !p.IsIPAccessAllowed(ip) { + if isPrivateIP(ip) { + return fmt.Errorf("private IP address not allowed: %s", ip.String()) + } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s", ip.String()) + } + return fmt.Errorf("ip in blacklist: %s", ip.String()) + } + return nil } - // DNS解析获取IP地址 + // 先进行域名过滤 + if !p.isDomainAllowed(host) { + if p.DomainFilterMode { + return fmt.Errorf("domain not in whitelist: %s", host) + } + return fmt.Errorf("domain in blacklist: %s", host) + } + + // 解析域名对应IP并检查 ips, err := net.LookupIP(host) if err != nil { return fmt.Errorf("DNS resolution failed for %s: %v", host, err) } - - // 检查所有解析的IP地址 for _, ip := range ips { if !p.IsIPAccessAllowed(ip) { - if isPrivateIP(ip) { + if isPrivateIP(ip) && !p.AllowPrivateIp { return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String()) - } else { - return fmt.Errorf("IP address not in whitelist: %s resolves to %s", host, ip.String()) } + if p.IpFilterMode { + return fmt.Errorf("ip not in whitelist: %s resolves to %s", host, ip.String()) + } + return fmt.Errorf("ip in blacklist: %s resolves to %s", host, ip.String()) } } - return nil } @@ -295,7 +301,7 @@ func ValidateURLWithDefaults(urlStr string) error { } // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL -func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, whitelistDomains, whitelistIps, allowedPorts []string) error { +func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string) error { // 如果SSRF防护被禁用,直接返回成功 if !enableSSRFProtection { return nil @@ -309,76 +315,11 @@ func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPriva protection := &SSRFProtection{ AllowPrivateIp: allowPrivateIp, - WhitelistDomains: whitelistDomains, - WhitelistIps: whitelistIps, + DomainFilterMode: domainFilterMode, + DomainList: domainList, + IpFilterMode: ipFilterMode, + IpList: ipList, AllowedPorts: allowedPortInts, } return protection.ValidateURL(urlStr) } - -// ValidateURLWithPortRanges 直接使用端口范围字符串验证URL(更高效的版本) -func ValidateURLWithPortRanges(urlStr string, allowPrivateIp bool, whitelistDomains, whitelistIps, allowedPorts []string) error { - // 解析URL - u, err := url.Parse(urlStr) - if err != nil { - return fmt.Errorf("invalid URL format: %v", err) - } - - // 只允许HTTP/HTTPS协议 - if u.Scheme != "http" && u.Scheme != "https" { - return fmt.Errorf("unsupported protocol: %s (only http/https allowed)", u.Scheme) - } - - // 解析主机和端口 - host, portStr, err := net.SplitHostPort(u.Host) - if err != nil { - // 没有端口,使用默认端口 - host = u.Host - if u.Scheme == "https" { - portStr = "443" - } else { - portStr = "80" - } - } - - // 验证端口 - port, err := strconv.Atoi(portStr) - if err != nil { - return fmt.Errorf("invalid port: %s", portStr) - } - - if !isAllowedPortFromRanges(port, allowedPorts) { - return fmt.Errorf("port %d is not allowed", port) - } - - // 创建临时的SSRFProtection来复用域名和IP检查逻辑 - protection := &SSRFProtection{ - AllowPrivateIp: allowPrivateIp, - WhitelistDomains: whitelistDomains, - WhitelistIps: whitelistIps, - } - - // 检查域名白名单 - if protection.isDomainWhitelisted(host) { - return nil // 白名单域名直接通过 - } - - // DNS解析获取IP地址 - ips, err := net.LookupIP(host) - if err != nil { - return fmt.Errorf("DNS resolution failed for %s: %v", host, err) - } - - // 检查所有解析的IP地址 - for _, ip := range ips { - if !protection.IsIPAccessAllowed(ip) { - if isPrivateIP(ip) { - return fmt.Errorf("private IP address not allowed: %s resolves to %s", host, ip.String()) - } else { - return fmt.Errorf("IP address not in whitelist: %s resolves to %s", host, ip.String()) - } - } - } - - return nil -} diff --git a/service/download.go b/service/download.go index 43b6fe7df..c07c9e1cd 100644 --- a/service/download.go +++ b/service/download.go @@ -30,7 +30,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { // SSRF防护:验证请求URL fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { return nil, fmt.Errorf("request reject: %v", err) } @@ -59,7 +59,7 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, } else { // SSRF防护:验证请求URL(非Worker模式) fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { return nil, fmt.Errorf("request reject: %v", err) } diff --git a/service/user_notify.go b/service/user_notify.go index 1e9e8947c..76d15903d 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -115,7 +115,7 @@ func sendBarkNotify(barkURL string, data dto.Notify) error { } else { // SSRF防护:验证Bark URL(非Worker模式) fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { return fmt.Errorf("request reject: %v", err) } diff --git a/service/webhook.go b/service/webhook.go index 5d9ce400a..b7fd13df6 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -89,7 +89,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error } else { // SSRF防护:验证Webhook URL(非Worker模式) fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { return fmt.Errorf("request reject: %v", err) }