diff --git a/common/ssrf_protection.go b/common/ssrf_protection.go index e48ca0e08..40d3b10b8 100644 --- a/common/ssrf_protection.go +++ b/common/ssrf_protection.go @@ -10,12 +10,13 @@ import ( // SSRFProtection SSRF防护配置 type SSRFProtection struct { - AllowPrivateIp bool - DomainFilterMode bool // true: 白名单, false: 黑名单 - DomainList []string // domain format, e.g. example.com, *.example.com - IpFilterMode bool // true: 白名单, false: 黑名单 - IpList []string // CIDR or single IP - AllowedPorts []int // 允许的端口范围 + AllowPrivateIp bool + DomainFilterMode bool // true: 白名单, false: 黑名单 + DomainList []string // domain format, e.g. example.com, *.example.com + IpFilterMode bool // true: 白名单, false: 黑名单 + IpList []string // CIDR or single IP + AllowedPorts []int // 允许的端口范围 + ApplyIPFilterForDomain bool // 对域名启用IP过滤 } // DefaultSSRFProtection 默认SSRF防护配置 @@ -276,6 +277,11 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error { return fmt.Errorf("domain in blacklist: %s", host) } + // 若未启用对域名应用IP过滤,则到此通过 + if !p.ApplyIPFilterForDomain { + return nil + } + // 解析域名对应IP并检查 ips, err := net.LookupIP(host) if err != nil { @@ -296,7 +302,7 @@ func (p *SSRFProtection) ValidateURL(urlStr string) error { } // ValidateURLWithFetchSetting 使用FetchSetting配置验证URL -func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string) error { +func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPrivateIp bool, domainFilterMode bool, ipFilterMode bool, domainList, ipList, allowedPorts []string, applyIPFilterForDomain bool) error { // 如果SSRF防护被禁用,直接返回成功 if !enableSSRFProtection { return nil @@ -309,12 +315,13 @@ func ValidateURLWithFetchSetting(urlStr string, enableSSRFProtection, allowPriva } protection := &SSRFProtection{ - AllowPrivateIp: allowPrivateIp, - DomainFilterMode: domainFilterMode, - DomainList: domainList, - IpFilterMode: ipFilterMode, - IpList: ipList, - AllowedPorts: allowedPortInts, + AllowPrivateIp: allowPrivateIp, + DomainFilterMode: domainFilterMode, + DomainList: domainList, + IpFilterMode: ipFilterMode, + IpList: ipList, + AllowedPorts: allowedPortInts, + ApplyIPFilterForDomain: applyIPFilterForDomain, } return protection.ValidateURL(urlStr) } diff --git a/service/download.go b/service/download.go index c07c9e1cd..036c43af8 100644 --- a/service/download.go +++ b/service/download.go @@ -30,7 +30,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { // SSRF防护:验证请求URL fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(req.URL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { return nil, fmt.Errorf("request reject: %v", err) } @@ -59,7 +59,7 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, } else { // SSRF防护:验证请求URL(非Worker模式) fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(originUrl, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { return nil, fmt.Errorf("request reject: %v", err) } diff --git a/service/user_notify.go b/service/user_notify.go index 76d15903d..fba12d9db 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -115,7 +115,7 @@ func sendBarkNotify(barkURL string, data dto.Notify) error { } else { // SSRF防护:验证Bark URL(非Worker模式) fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(finalURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { return fmt.Errorf("request reject: %v", err) } diff --git a/service/webhook.go b/service/webhook.go index b7fd13df6..c678b8634 100644 --- a/service/webhook.go +++ b/service/webhook.go @@ -89,7 +89,7 @@ func SendWebhookNotify(webhookURL string, secret string, data dto.Notify) error } else { // SSRF防护:验证Webhook URL(非Worker模式) fetchSetting := system_setting.GetFetchSetting() - if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts); err != nil { + if err := common.ValidateURLWithFetchSetting(webhookURL, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { return fmt.Errorf("request reject: %v", err) } diff --git a/setting/system_setting/fetch_setting.go b/setting/system_setting/fetch_setting.go index 5277e1033..3c7f1e059 100644 --- a/setting/system_setting/fetch_setting.go +++ b/setting/system_setting/fetch_setting.go @@ -3,23 +3,25 @@ package system_setting import "one-api/setting/config" type FetchSetting struct { - EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护 - AllowPrivateIp bool `json:"allow_private_ip"` - DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式 - IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式 - DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com - IpList []string `json:"ip_list"` // CIDR format - AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000 + EnableSSRFProtection bool `json:"enable_ssrf_protection"` // 是否启用SSRF防护 + AllowPrivateIp bool `json:"allow_private_ip"` + DomainFilterMode bool `json:"domain_filter_mode"` // 域名过滤模式,true: 白名单模式,false: 黑名单模式 + IpFilterMode bool `json:"ip_filter_mode"` // IP过滤模式,true: 白名单模式,false: 黑名单模式 + DomainList []string `json:"domain_list"` // domain format, e.g. example.com, *.example.com + IpList []string `json:"ip_list"` // CIDR format + AllowedPorts []string `json:"allowed_ports"` // port range format, e.g. 80, 443, 8000-9000 + ApplyIPFilterForDomain bool `json:"apply_ip_filter_for_domain"` // 对域名启用IP过滤(实验性) } var defaultFetchSetting = FetchSetting{ - EnableSSRFProtection: true, // 默认开启SSRF防护 - AllowPrivateIp: false, - DomainFilterMode: true, - IpFilterMode: true, - DomainList: []string{}, - IpList: []string{}, - AllowedPorts: []string{"80", "443", "8080", "8443"}, + EnableSSRFProtection: true, // 默认开启SSRF防护 + AllowPrivateIp: false, + DomainFilterMode: true, + IpFilterMode: true, + DomainList: []string{}, + IpList: []string{}, + AllowedPorts: []string{"80", "443", "8080", "8443"}, + ApplyIPFilterForDomain: false, } func init() { diff --git a/web/src/components/settings/SystemSetting.jsx b/web/src/components/settings/SystemSetting.jsx index ebe4084be..a1d26a4ad 100644 --- a/web/src/components/settings/SystemSetting.jsx +++ b/web/src/components/settings/SystemSetting.jsx @@ -97,6 +97,7 @@ const SystemSetting = () => { 'fetch_setting.domain_list': [], 'fetch_setting.ip_list': [], 'fetch_setting.allowed_ports': [], + 'fetch_setting.apply_ip_filter_for_domain': false, }); const [originInputs, setOriginInputs] = useState({}); @@ -132,6 +133,7 @@ const SystemSetting = () => { case 'fetch_setting.enable_ssrf_protection': case 'fetch_setting.domain_filter_mode': case 'fetch_setting.ip_filter_mode': + case 'fetch_setting.apply_ip_filter_for_domain': item.value = toBoolean(item.value); break; case 'fetch_setting.domain_list': @@ -724,6 +726,17 @@ const SystemSetting = () => { style={{ marginTop: 16 }} >