diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 9540195cd..57f15bd33 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -98,6 +98,19 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s return headerOverride, nil } +func applyHeaderOverrideToRequest(req *http.Request, headerOverride map[string]string) { + if req == nil { + return + } + for key, value := range headerOverride { + req.Header.Set(key, value) + // set Host in req + if strings.EqualFold(key, "Host") { + req.Host = value + } + } +} + func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(info) if err != nil { @@ -121,9 +134,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if err != nil { return nil, err } - for key, value := range headerOverride { - headers.Set(key, value) - } + applyHeaderOverrideToRequest(req, headerOverride) resp, err := doRequest(c, req, info) if err != nil { return nil, fmt.Errorf("do request failed: %w", err) @@ -156,9 +167,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod if err != nil { return nil, err } - for key, value := range headerOverride { - headers.Set(key, value) - } + applyHeaderOverrideToRequest(req, headerOverride) resp, err := doRequest(c, req, info) if err != nil { return nil, fmt.Errorf("do request failed: %w", err)