Merge pull request #2840 from seefs001/feature/header-regex-override

feat: 支持基于Go Regex规则和全量的请求体透传
This commit is contained in:
Calcium-Ion
2026-02-05 01:56:27 +08:00
committed by GitHub
8 changed files with 194 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"
"sync"
"time"
@@ -40,6 +41,86 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
const clientHeaderPlaceholderPrefix = "{client_header:"
const (
headerPassthroughAllKey = "*"
headerPassthroughRegexPrefix = "re:"
headerPassthroughRegexPrefixV2 = "regex:"
)
var passthroughSkipHeaderNamesLower = map[string]struct{}{
// RFC 7230 hop-by-hop headers.
"connection": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"transfer-encoding": {},
"upgrade": {},
// Additional headers that should not be forwarded by name-matching passthrough rules.
"host": {},
"content-length": {},
// Do not passthrough credentials by wildcard/regex.
"authorization": {},
"x-api-key": {},
"x-goog-api-key": {},
// WebSocket handshake headers are generated by the client/dialer.
"sec-websocket-key": {},
"sec-websocket-version": {},
"sec-websocket-extensions": {},
}
var headerPassthroughRegexCache sync.Map // map[string]*regexp.Regexp
func getHeaderPassthroughRegex(pattern string) (*regexp.Regexp, error) {
pattern = strings.TrimSpace(pattern)
if pattern == "" {
return nil, errors.New("empty regex pattern")
}
if v, ok := headerPassthroughRegexCache.Load(pattern); ok {
if re, ok := v.(*regexp.Regexp); ok {
return re, nil
}
headerPassthroughRegexCache.Delete(pattern)
}
compiled, err := regexp.Compile(pattern)
if err != nil {
return nil, err
}
actual, _ := headerPassthroughRegexCache.LoadOrStore(pattern, compiled)
if re, ok := actual.(*regexp.Regexp); ok {
return re, nil
}
return compiled, nil
}
func isHeaderPassthroughRuleKey(key string) bool {
key = strings.TrimSpace(key)
if key == "" {
return false
}
if key == headerPassthroughAllKey {
return true
}
lower := strings.ToLower(key)
return strings.HasPrefix(lower, headerPassthroughRegexPrefix) || strings.HasPrefix(lower, headerPassthroughRegexPrefixV2)
}
func shouldSkipPassthroughHeader(name string) bool {
name = strings.TrimSpace(name)
if name == "" {
return true
}
lower := strings.ToLower(name)
if _, ok := passthroughSkipHeaderNamesLower[lower]; ok {
return true
}
return false
}
func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
trimmed := strings.TrimSpace(template)
if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
@@ -77,9 +158,85 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
// Supported placeholders:
// - {api_key}: resolved to the channel API key
// - {client_header:<name>}: resolved to the incoming request header value
//
// Header passthrough rules (keys only; values are ignored):
// - "*": passthrough all incoming headers by name (excluding unsafe headers)
// - "re:<regex>" / "regex:<regex>": passthrough headers whose names match the regex (Go regexp)
//
// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
headerOverride := make(map[string]string)
passAll := false
var passthroughRegex []*regexp.Regexp
for k := range info.HeadersOverride {
key := strings.TrimSpace(k)
if key == "" {
continue
}
if key == headerPassthroughAllKey {
passAll = true
continue
}
lower := strings.ToLower(key)
var pattern string
switch {
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
default:
continue
}
if pattern == "" {
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
}
compiled, err := getHeaderPassthroughRegex(pattern)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
}
passthroughRegex = append(passthroughRegex, compiled)
}
if passAll || len(passthroughRegex) > 0 {
if c == nil || c.Request == nil {
return nil, types.NewError(fmt.Errorf("missing request context for header passthrough"), types.ErrorCodeChannelHeaderOverrideInvalid)
}
for name := range c.Request.Header {
if shouldSkipPassthroughHeader(name) {
continue
}
if !passAll {
matched := false
for _, re := range passthroughRegex {
if re.MatchString(name) {
matched = true
break
}
}
if !matched {
continue
}
}
value := strings.TrimSpace(c.Request.Header.Get(name))
if value == "" {
continue
}
headerOverride[name] = value
}
}
for k, v := range info.HeadersOverride {
if isHeaderPassthroughRuleKey(k) {
continue
}
key := strings.TrimSpace(k)
if key == "" {
continue
}
str, ok := v.(string)
if !ok {
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
@@ -93,7 +250,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
continue
}
headerOverride[k] = value
headerOverride[key] = value
}
return headerOverride, nil
}