diff --git a/service/download.go b/service/download.go index 036c43af8..09464390e 100644 --- a/service/download.go +++ b/service/download.go @@ -45,7 +45,7 @@ func DoWorkerRequest(req *WorkerRequest) (*http.Response, error) { return nil, fmt.Errorf("failed to marshal worker payload: %v", err) } - return http.Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) + return GetHttpClient().Post(workerUrl, "application/json", bytes.NewBuffer(workerPayload)) } func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, err error) { @@ -64,6 +64,6 @@ func DoDownloadRequest(originUrl string, reason ...string) (resp *http.Response, } common.SysLog(fmt.Sprintf("downloading from origin: %s, reason: %s", common.MaskSensitiveInfo(originUrl), strings.Join(reason, ", "))) - return http.Get(originUrl) + return GetHttpClient().Get(originUrl) } } diff --git a/service/http_client.go b/service/http_client.go index c1d6880c9..79eca4a98 100644 --- a/service/http_client.go +++ b/service/http_client.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "one-api/common" + "one-api/setting/system_setting" "sync" "time" @@ -19,12 +20,27 @@ var ( proxyClients = make(map[string]*http.Client) ) +func checkRedirect(req *http.Request, via []*http.Request) error { + fetchSetting := system_setting.GetFetchSetting() + urlStr := req.URL.String() + if err := common.ValidateURLWithFetchSetting(urlStr, fetchSetting.EnableSSRFProtection, fetchSetting.AllowPrivateIp, fetchSetting.DomainFilterMode, fetchSetting.IpFilterMode, fetchSetting.DomainList, fetchSetting.IpList, fetchSetting.AllowedPorts, fetchSetting.ApplyIPFilterForDomain); err != nil { + return fmt.Errorf("redirect to %s blocked: %v", urlStr, err) + } + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + return nil +} + func InitHttpClient() { if common.RelayTimeout == 0 { - httpClient = &http.Client{} + httpClient = &http.Client{ + CheckRedirect: checkRedirect, + } } else { httpClient = &http.Client{ - Timeout: time.Duration(common.RelayTimeout) * time.Second, + Timeout: time.Duration(common.RelayTimeout) * time.Second, + CheckRedirect: checkRedirect, } } } @@ -69,6 +85,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { Transport: &http.Transport{ Proxy: http.ProxyURL(parsedURL), }, + CheckRedirect: checkRedirect, } client.Timeout = time.Duration(common.RelayTimeout) * time.Second proxyClientLock.Lock() @@ -102,6 +119,7 @@ func NewProxyHttpClient(proxyURL string) (*http.Client, error) { return dialer.Dial(network, addr) }, }, + CheckRedirect: checkRedirect, } client.Timeout = time.Duration(common.RelayTimeout) * time.Second proxyClientLock.Lock()