mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
feat: add header passthrough
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -40,6 +41,86 @@ func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Hea
|
||||
|
||||
const clientHeaderPlaceholderPrefix = "{client_header:"
|
||||
|
||||
const (
|
||||
headerPassthroughAllKey = "*"
|
||||
headerPassthroughRegexPrefix = "re:"
|
||||
headerPassthroughRegexPrefixV2 = "regex:"
|
||||
)
|
||||
|
||||
var passthroughSkipHeaderNamesLower = map[string]struct{}{
|
||||
// RFC 7230 hop-by-hop headers.
|
||||
"connection": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
|
||||
// Additional headers that should not be forwarded by name-matching passthrough rules.
|
||||
"host": {},
|
||||
"content-length": {},
|
||||
|
||||
// Do not passthrough credentials by wildcard/regex.
|
||||
"authorization": {},
|
||||
"x-api-key": {},
|
||||
"x-goog-api-key": {},
|
||||
|
||||
// WebSocket handshake headers are generated by the client/dialer.
|
||||
"sec-websocket-key": {},
|
||||
"sec-websocket-version": {},
|
||||
"sec-websocket-extensions": {},
|
||||
}
|
||||
|
||||
var headerPassthroughRegexCache sync.Map // map[string]*regexp.Regexp
|
||||
|
||||
func getHeaderPassthroughRegex(pattern string) (*regexp.Regexp, error) {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
if pattern == "" {
|
||||
return nil, errors.New("empty regex pattern")
|
||||
}
|
||||
if v, ok := headerPassthroughRegexCache.Load(pattern); ok {
|
||||
if re, ok := v.(*regexp.Regexp); ok {
|
||||
return re, nil
|
||||
}
|
||||
headerPassthroughRegexCache.Delete(pattern)
|
||||
}
|
||||
compiled, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
actual, _ := headerPassthroughRegexCache.LoadOrStore(pattern, compiled)
|
||||
if re, ok := actual.(*regexp.Regexp); ok {
|
||||
return re, nil
|
||||
}
|
||||
return compiled, nil
|
||||
}
|
||||
|
||||
func isHeaderPassthroughRuleKey(key string) bool {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
if key == headerPassthroughAllKey {
|
||||
return true
|
||||
}
|
||||
lower := strings.ToLower(key)
|
||||
return strings.HasPrefix(lower, headerPassthroughRegexPrefix) || strings.HasPrefix(lower, headerPassthroughRegexPrefixV2)
|
||||
}
|
||||
|
||||
func shouldSkipPassthroughHeader(name string) bool {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
lower := strings.ToLower(name)
|
||||
if _, ok := passthroughSkipHeaderNamesLower[lower]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
|
||||
trimmed := strings.TrimSpace(template)
|
||||
if strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
|
||||
@@ -77,9 +158,85 @@ func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey str
|
||||
// Supported placeholders:
|
||||
// - {api_key}: resolved to the channel API key
|
||||
// - {client_header:<name>}: resolved to the incoming request header value
|
||||
//
|
||||
// Header passthrough rules (keys only; values are ignored):
|
||||
// - "*": passthrough all incoming headers by name (excluding unsafe headers)
|
||||
// - "re:<regex>" / "regex:<regex>": passthrough headers whose names match the regex (Go regexp)
|
||||
//
|
||||
// Passthrough rules are applied first, then normal overrides are applied, so explicit overrides win.
|
||||
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
|
||||
headerOverride := make(map[string]string)
|
||||
|
||||
passAll := false
|
||||
var passthroughRegex []*regexp.Regexp
|
||||
for k := range info.HeadersOverride {
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if key == headerPassthroughAllKey {
|
||||
passAll = true
|
||||
continue
|
||||
}
|
||||
|
||||
lower := strings.ToLower(key)
|
||||
var pattern string
|
||||
switch {
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefix):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefix):])
|
||||
case strings.HasPrefix(lower, headerPassthroughRegexPrefixV2):
|
||||
pattern = strings.TrimSpace(key[len(headerPassthroughRegexPrefixV2):])
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
if pattern == "" {
|
||||
return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
compiled, err := getHeaderPassthroughRegex(pattern)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
passthroughRegex = append(passthroughRegex, compiled)
|
||||
}
|
||||
|
||||
if passAll || len(passthroughRegex) > 0 {
|
||||
if c == nil || c.Request == nil {
|
||||
return nil, types.NewError(fmt.Errorf("missing request context for header passthrough"), types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
}
|
||||
for name := range c.Request.Header {
|
||||
if shouldSkipPassthroughHeader(name) {
|
||||
continue
|
||||
}
|
||||
if !passAll {
|
||||
matched := false
|
||||
for _, re := range passthroughRegex {
|
||||
if re.MatchString(name) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
}
|
||||
value := strings.TrimSpace(c.Request.Header.Get(name))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
headerOverride[name] = value
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range info.HeadersOverride {
|
||||
if isHeaderPassthroughRuleKey(k) {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(k)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, types.NewError(nil, types.ErrorCodeChannelHeaderOverrideInvalid)
|
||||
@@ -93,7 +250,7 @@ func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]s
|
||||
continue
|
||||
}
|
||||
|
||||
headerOverride[k] = value
|
||||
headerOverride[key] = value
|
||||
}
|
||||
return headerOverride, nil
|
||||
}
|
||||
|
||||
@@ -3113,6 +3113,28 @@ const EditChannelModal = (props) => {
|
||||
extraText={
|
||||
<div className='flex flex-col gap-1'>
|
||||
<div className='flex gap-2 flex-wrap items-center'>
|
||||
<Text
|
||||
className='!text-semi-color-primary cursor-pointer'
|
||||
onClick={() =>
|
||||
handleInputChange(
|
||||
'header_override',
|
||||
JSON.stringify(
|
||||
{
|
||||
'*': true,
|
||||
're:^X-Trace-.*$': true,
|
||||
'X-Foo': '{client_header:X-Foo}',
|
||||
Authorization: 'Bearer {api_key}',
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0',
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('填入模板')}
|
||||
</Text>
|
||||
<Text
|
||||
className='!text-semi-color-primary cursor-pointer'
|
||||
onClick={() =>
|
||||
@@ -3120,9 +3142,7 @@ const EditChannelModal = (props) => {
|
||||
'header_override',
|
||||
JSON.stringify(
|
||||
{
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36 Edg/139.0.0.0',
|
||||
Authorization: 'Bearer{api_key}',
|
||||
'*': true,
|
||||
},
|
||||
null,
|
||||
2,
|
||||
@@ -3130,7 +3150,7 @@ const EditChannelModal = (props) => {
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('填入模板')}
|
||||
{t('填入透传模版')}
|
||||
</Text>
|
||||
<Text
|
||||
className='!text-semi-color-primary cursor-pointer'
|
||||
|
||||
@@ -739,6 +739,8 @@
|
||||
"填入": "Fill",
|
||||
"填入所有模型": "Fill in all models",
|
||||
"填入模板": "Fill Template",
|
||||
"填入透传模版": "Fill Passthrough Template",
|
||||
"填入透传完整模版": "Fill Full Passthrough Template",
|
||||
"填入相关模型": "Fill Related Models",
|
||||
"填写Gotify服务器的完整URL地址": "Fill in the complete URL address of the Gotify server",
|
||||
"填写带https的域名,逗号分隔": "Fill in domains with https, separated by commas",
|
||||
|
||||
@@ -744,6 +744,8 @@
|
||||
"填入": "Remplir",
|
||||
"填入所有模型": "Remplir tous les modèles",
|
||||
"填入模板": "Remplir le modèle",
|
||||
"填入透传模版": "Remplir le modèle passthrough",
|
||||
"填入透传完整模版": "Remplir le modèle passthrough complet",
|
||||
"填入相关模型": "Remplir les modèles associés",
|
||||
"填写Gotify服务器的完整URL地址": "Remplir l'adresse URL complète du serveur Gotify",
|
||||
"填写带https的域名,逗号分隔": "Saisir les domaines avec https, séparés par des virgules",
|
||||
|
||||
@@ -735,6 +735,8 @@
|
||||
"填入": "入力",
|
||||
"填入所有模型": "すべてのモデルを入力",
|
||||
"填入模板": "テンプレートを入力",
|
||||
"填入透传模版": "パススルーテンプレートを入力",
|
||||
"填入透传完整模版": "完全なパススルーテンプレートを入力",
|
||||
"填入相关模型": "関連モデルを入力",
|
||||
"填写Gotify服务器的完整URL地址": "Gotifyサーバーの完全なURLを入力してください",
|
||||
"填写带https的域名,逗号分隔": "https://を含むドメインをカンマ区切りで入力してください",
|
||||
|
||||
@@ -750,6 +750,8 @@
|
||||
"填入": "Заполнить",
|
||||
"填入所有模型": "Заполнить все модели",
|
||||
"填入模板": "Заполнить шаблон",
|
||||
"填入透传模版": "Заполнить шаблон passthrough",
|
||||
"填入透传完整模版": "Заполнить полный шаблон passthrough",
|
||||
"填入相关模型": "Заполнить связанные модели",
|
||||
"填写Gotify服务器的完整URL地址": "Введите полный URL-адрес сервера Gotify",
|
||||
"填写带https的域名,逗号分隔": "Введите домены с https, разделённые запятыми",
|
||||
|
||||
@@ -736,6 +736,8 @@
|
||||
"填入": "Điền",
|
||||
"填入所有模型": "Điền tất cả mô hình",
|
||||
"填入模板": "Điền mẫu",
|
||||
"填入透传模版": "Điền mẫu truyền qua",
|
||||
"填入透传完整模版": "Điền mẫu truyền qua đầy đủ",
|
||||
"填入相关模型": "Điền mô hình liên quan",
|
||||
"填写Gotify服务器的完整URL地址": "Điền địa chỉ URL đầy đủ của máy chủ Gotify",
|
||||
"填写带https的域名,逗号分隔": "Điền tên miền có https, phân tách bằng dấu phẩy",
|
||||
|
||||
@@ -735,6 +735,8 @@
|
||||
"填入": "填入",
|
||||
"填入所有模型": "填入所有模型",
|
||||
"填入模板": "填入模板",
|
||||
"填入透传模版": "填入透传模版",
|
||||
"填入透传完整模版": "填入透传完整模版",
|
||||
"填入相关模型": "填入相关模型",
|
||||
"填写Gotify服务器的完整URL地址": "填写Gotify服务器的完整URL地址",
|
||||
"填写带https的域名,逗号分隔": "填写带https的域名,逗号分隔",
|
||||
|
||||
Reference in New Issue
Block a user