mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 04:40:59 +00:00
Merge pull request #2840 from seefs001/feature/header-regex-override
feat: 支持基于Go Regex规则和全量的请求体透传
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -40,6 +41,86 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
|
||||
|
||||
const clientHeaderPlaceholderPrefix = "{client_header:"
|
||||
|
||||
const (
|
||||
headerPassthroughAllKey = "*"
|
||||
headerPassthroughRegexPrefix = "re:"
|
||||
headerPassthroughRegexPrefixV2 = "regex:"
|
||||
)
|
||||
|
||||
var passthroughSkipHeaderNamesLower = map[string]struct{}{
|
||||
// RFC 7230 hop-by-hop headers.
|
||||
"connection": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
|
||||
// Additional headers that should not be forwarded by name-matching passthrough rules.
|
||||
"host": {},
|
||||
"content-length": {},
|
||||
|
||||
// Do not passthrough credentials by wildcard/regex.
|
||||
"authorization": {},
|
||||
"x-api-key": {},
|
||||
"x-goog-api-key": {},
|
||||
|
||||
// WebSocket handshake headers are generated by the client/dialer.
|
||||
"sec-websocket-key": {},
|
||||
"sec-websocket-version": {},
|
||||
"sec-websocket-extensions": {},
|
||||
}
|
||||
|
||||
var headerPassthroughRegexCache sync.Map // map[string]*regexp.Regexp
|
||||
|
||||
func getHeaderPassthroughRegex(pattern string) (*regexp.Regexp, error) {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
if pattern == "" {
|
||||
return nil, errors.New("empty regex pattern")
|
||||
}
|
||||
if v, ok := headerPassthroughRegexCache.Load(pattern); ok {
|
||||
if re, ok := v.(*regexp.Regexp); ok {
|
||||
return re, nil
|
||||
}
|
||||
headerPassthroughRegexCache.Delete(pattern)
|
||||
}
|
||||
compiled, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
actual, _ := headerPassthroughRegexCache.LoadOrStore(pattern, compiled)
|
||||
if re, ok := actual.(*regexp.Regexp); ok {
|
||||
return re, nil
|
||||
}
|
||||
return compiled, nil
|
||||
}
|
||||
|
||||
func isHeaderPassthroughRuleKey(key string) bool {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
if key == headerPassthroughAllKey {
|
||||
return true
|
||||
}
|
||||
lower := strings.ToLower(key)
|
||||
return strings.HasPrefix(lower, headerPassthroughRegexPrefix) || strings.HasPrefix(lower, headerPassthroughRegexPrefixV2)
|
||||
}
|
||||
|
||||
func shouldSkipPassthroughHeader(name string) bool {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
lower := strings.ToLower(name)
|
||||
if _, ok := passthroughSkipHeaderNamesLower[lower]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
|
||||
trimmed := strings.TrimSpace(template)
|
||||
if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
|
||||
@@ -77,9 +158,85 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
|
||||
// Supported placeholders:
|
||||
// - {api_key}: resolved to the channel API key
|
||||
// - {client_header:<name>}: resolved to the incoming request header value
|
||||
//
|
||||
// Header passthrough rules (keys only; values are ignored):
|
||||
// - "*": passthrough all incoming headers by name (excluding unsafe headers)
|
||||
// - "re:<regex>" / "regex:<regex>": passthrough headers whose names match the regex (Go regexp)
|
||||
//
|
||||
// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
|
||||
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
|
||||
headerOverride := make(map[string]string)
|
||||
|
||||
passAll := false
|
||||
var passthroughRegex []*regexp.Regexp
|
||||
for k := range info.HeadersOverride {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if key == headerPassthroughAllKey {
|
||||
passAll = true
|
||||
continue
|
||||
}
|
||||
|
||||
lower := strings.ToLower(key)
|
||||
var pattern string
|
||||
switch {
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if pattern == "" {
|
||||
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
compiled, err := getHeaderPassthroughRegex(pattern)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
passthroughRegex = append(passthroughRegex, compiled)
|
||||
}
|
||||
|
||||
if passAll || len(passthroughRegex) > 0 {
|
||||
if c == nil || c.Request == nil {
|
||||
return nil, types.NewError(fmt.Errorf("missing request context for header passthrough"), types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
for name := range c.Request.Header {
|
||||
if shouldSkipPassthroughHeader(name) {
|
||||
continue
|
||||
}
|
||||
if !passAll {
|
||||
matched := false
|
||||
for _, re := range passthroughRegex {
|
||||
if re.MatchString(name) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
}
|
||||
value := strings.TrimSpace(c.Request.Header.Get(name))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
headerOverride[name] = value
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range info.HeadersOverride {
|
||||
if isHeaderPassthroughRuleKey(k) {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
@@ -93,7 +250,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
||||
continue
|
||||
}
|
||||
|
||||
headerOverride[k] = value
|
||||
headerOverride[key] = value
|
||||
}
|
||||
return headerOverride, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user